添加 class 信息到 keras 网络

Add class information to keras network

我想弄清楚如何将我的数据集的标签信息与生成对抗网络一起使用。我正在尝试使用 can be found here 的条件 GAN 的以下实现。我的数据集包含两个不同的图像域(真实对象和草图),具有共同的 class 信息(椅子、树、橙子等)。我选择了这个实现,它只考虑两个不同的域作为对应的不同 "classes"(训练样本 X 对应真实图像,而目标样本 y 对应草图图像)。

有没有办法修改我的代码并在我的整个架构中考虑 class 信息(椅子、树等)?我实际上希望我的鉴别器预测我从生成器生成的图像是否属于特定的 class 而不仅仅是它们是否真实。实际上,在当前架构下,系统会学习在所有情况下创建相似的草图。

更新:鉴别器returns一个大小为1x7x7的张量然后y_truey_pred都通过一个在计算损失之前展平图层:

def discriminator_loss(y_true, y_pred):
     BATCH_SIZE=100
     return K.mean(K.binary_crossentropy(K.flatten(y_pred), K.concatenate([K.ones_like(K.flatten(y_pred[:BATCH_SIZE,:,:,:])),K.zeros_like(K.flatten(y_pred[:BATCH_SIZE,:,:,:])) ]) ), axis=-1)

以及判别器对生成器的损失函数:

def discriminator_on_generator_loss(y_true,y_pred):
     BATCH_SIZE=100
     return K.mean(K.binary_crossentropy(K.flatten(y_pred), K.ones_like(K.flatten(y_pred))), axis=-1)

此外,我对输出1层的判别器模型的修改:

model.add(Flatten())
model.add(Dense(1, activation='sigmoid'))
#model.add(Activation('sigmoid'))

现在鉴别器输出1层。如何相应地修改上述损失函数?对于 n_classes = 6 + 一个 class 来预测真假对,我应该用 7 个而不是 1 个吗?

您应该修改鉴别器模型,使其具有两个输出或 "n_classes + 1" 输出。

警告:我在鉴别器的定义中没有看到它输出 'true/false',我看到它输出图像...

某处它应该包含 GlobalMaxPooling2DGlobalAveragePooling2D
最后还有一个或多个 Dense 层用于 class 化。

如果告诉true/false,最后一个Dense应该有1个单位。
否则 n_classes + 1 单位。

所以,你的鉴别器的结尾应该是这样的

...GlobalMaxPooling2D()...
...Dense(someHidden,...)...
...Dense(n_classes+1,...)...

判别器现在将输出 n_classes 加上一个 "true/fake" 符号(你将无法在那里使用 "categorical")或者甚至是一个 "fake class"(然后你将其他 classes 归零并使用分类)

您生成的草图应该连同目标一起传递给鉴别器,该目标将是假 class 与另一个 class 的串联。

选项 1 - 使用 "true/fake" 符号。 (不要使用 "categorical_crossentropy")

#true sketches into discriminator:
fakeClass = np.zeros((total_samples,))
sketchClass = originalClasses

targetClassTrue = np.concatenate([fakeClass,sketchClass], axis=-1)

#fake sketches into discriminator:
fakeClass = np.ones((total_fake_sketches))
sketchClass = originalClasses

targetClassFake = np.concatenate([fakeClass,sketchClass], axis=-1)

选项 2 - 使用 "fake class"(可以使用 "categorical_crossentropy"):

#true sketches into discriminator:
fakeClass = np.zeros((total_samples,))
sketchClass = originalClasses

targetClassTrue = np.concatenate([fakeClass,sketchClass], axis=-1)

#fake sketches into discriminator:
fakeClass = np.ones((total_fake_sketches))
sketchClass = np.zeros((total_fake_sketches, n_classes))

targetClassFake = np.concatenate([fakeClass,sketchClass], axis=-1)

现在将所有内容连接成一个目标数组(相对于输入草图)

更新训练方法

对于这种训练方法,你的损失函数应该是以下之一:

  • discriminator.compile(loss='binary_crossentropy', optimizer=....)
  • discriminator.compile(loss='categorical_crossentropy', optimizer=...)

代码:

for epoch in range(100):
    print("Epoch is", epoch)
    print("Number of batches", int(X_train.shape[0]/BATCH_SIZE))

    for index in range(int(X_train.shape[0]/BATCH_SIZE)):

        #names:
            #images -> initial images, not changed    
            #sketches -> generated + true sketches    
            #classes -> your classification for the images    
            #isGenerated -> the output of your discriminator telling whether the passed sketches are fake

        batchSlice = slice(index*BATCH_SIZE,(index+1)*BATCH_SIZE)
        trueImages = X_train[batchSlice]

        trueSketches = Y_train[batchSlice] 
        trueClasses = originalClasses[batchSlice]
        trueIsGenerated = np.zeros((len(trueImages),)) #discriminator telling whether the sketch is fake or true (generated images = 1)
        trueEndTargets = np.concatenate([trueIsGenerated,trueClasses],axis=1)

        fakeSketches = generator.predict(trueImages)
        fakeClasses = originalClasses[batchSlize]             #if option 1 -> telling class + isGenerated - use "binary_crossentropy"
        fakeClasses = np.zeros((len(fakeSketches),n_classes)) #if option 2 -> telling if generated is an individual class - use "categorical_crossentropy"    
        fakeIsGenerated = np.ones((len(fakeSketches),))
        fakeEndTargets = np.concatenate([fakeIsGenerated, fakeClasses], axis=1)

        allSketches = np.concatenate([trueSketches,fakeSketches],axis=0)            
        allEndTargets = np.concatenate([trueEndTargets,fakeEndTargets],axis=0)

        d_loss = discriminator.train_on_batch(allSketches, allEndTargets)

        pred_temp = discriminator.predict(allSketches)
        #print(np.shape(pred_temp))
        print("batch %d d_loss : %f" % (index, d_loss))

        ##WARNING## In previous keras versions, "trainable" only takes effect if you compile the models. 
            #you should have the "discriminator" and the "discriminator_on_generator" with these set at the creation of the models and never change it again   

        discriminator.trainable = False
        g_loss = discriminator_on_generator.train_on_batch(trueImages, trueEndTargets)
        discriminator.trainable = True


        print("batch %d g_loss : %f" % (index, g_loss[1]))
        if index % 20 == 0:
            generator.save_weights('generator', True)
            discriminator.save_weights('discriminator', True)

正确编译模型

当您创建 "discriminator" 和 "discriminator_on_generator" 时:

discriminator.trainable = True
for l in discriminator.layers:
    l.trainable = True


discriminator.compile(.....)

for l in discriminator_on_generator.layer[firstDiscriminatorLayer:]:
    l.trainable = False

discriminator_on_generator.compile(....)

建议的解决方案

重用 repository you shared 中的代码,这里有一些建议的修改,以沿着生成器和鉴别器训练 classifier(它们的架构和其他损失保持不变):

from keras import backend as K
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Convolution2D, MaxPooling2D

def lenet_classifier_model(nb_classes):
    # Snipped by Fabien Tanc - https://www.kaggle.com/ftence/keras-cnn-inspired-by-lenet-5
    # Replace with your favorite classifier...
    model = Sequential()
    model.add(Convolution2D(12, 5, 5, activation='relu', input_shape=in_shape, init='he_normal'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Convolution2D(25, 5, 5, activation='relu', init='he_normal'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())
    model.add(Dense(180, activation='relu', init='he_normal'))
    model.add(Dropout(0.5))
    model.add(Dense(100, activation='relu', init='he_normal'))
    model.add(Dropout(0.5))
    model.add(Dense(nb_classes, activation='softmax', init='he_normal'))

def generator_containing_discriminator_and_classifier(generator, discriminator, classifier):
    inputs = Input((IN_CH, img_cols, img_rows))
    x_generator = generator(inputs)

    merged = merge([inputs, x_generator], mode='concat', concat_axis=1)
    discriminator.trainable = False
    x_discriminator = discriminator(merged)

    classifier.trainable = False
    x_classifier = classifier(x_generator)

    model = Model(input=inputs, output=[x_generator, x_discriminator, x_classifier])

    return model


def train(BATCH_SIZE):
    (X_train, Y_train, LABEL_train) = get_data('train')  # replace with your data here
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    Y_train = (Y_train.astype(np.float32) - 127.5) / 127.5
    discriminator = discriminator_model()
    generator = generator_model()
    classifier = lenet_classifier_model(6)
    generator.summary()
    discriminator_and_classifier_on_generator = generator_containing_discriminator_and_classifier(
        generator, discriminator, classifier)
    d_optim = Adagrad(lr=0.005)
    g_optim = Adagrad(lr=0.005)
    generator.compile(loss='mse', optimizer="rmsprop")
    discriminator_and_classifier_on_generator.compile(
        loss=[generator_l1_loss, discriminator_on_generator_loss, "categorical_crossentropy"],
        optimizer="rmsprop")
    discriminator.trainable = True
    discriminator.compile(loss=discriminator_loss, optimizer="rmsprop")
    classifier.trainable = True
    classifier.compile(loss="categorical_crossentropy", optimizer="rmsprop")

    for epoch in range(100):
        print("Epoch is", epoch)
        print("Number of batches", int(X_train.shape[0] / BATCH_SIZE))
        for index in range(int(X_train.shape[0] / BATCH_SIZE)):
            image_batch = Y_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE]
            label_batch = LABEL_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE]  # replace with your data here

            generated_images = generator.predict(X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE])
            if index % 20 == 0:
                image = combine_images(generated_images)
                image = image * 127.5 + 127.5
                image = np.swapaxes(image, 0, 2)
                cv2.imwrite(str(epoch) + "_" + str(index) + ".png", image)
                # Image.fromarray(image.astype(np.uint8)).save(str(epoch)+"_"+str(index)+".png")

            # Training D:
            real_pairs = np.concatenate((X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE, :, :, :], image_batch),
                                        axis=1)
            fake_pairs = np.concatenate(
                (X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE, :, :, :], generated_images), axis=1)
            X = np.concatenate((real_pairs, fake_pairs))
            y = np.zeros((20, 1, 64, 64))  # [1] * BATCH_SIZE + [0] * BATCH_SIZE
            d_loss = discriminator.train_on_batch(X, y)
            print("batch %d d_loss : %f" % (index, d_loss))
            discriminator.trainable = False

            # Training C:
            c_loss = classifier.train_on_batch(image_batch, label_batch)
            print("batch %d c_loss : %f" % (index, c_loss))
            classifier.trainable = False

            # Train G:
            g_loss = discriminator_and_classifier_on_generator.train_on_batch(
                X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE, :, :, :], 
                [image_batch, np.ones((10, 1, 64, 64)), label_batch])
            discriminator.trainable = True
            classifier.trainable = True
            print("batch %d g_loss : %f" % (index, g_loss[1]))
            if index % 20 == 0:
                generator.save_weights('generator', True)
                discriminator.save_weights('discriminator', True)

理论细节

我认为对于条件 GAN 的工作原理以及鉴别器在此类方案中的作用存在一些误解。

鉴别器的作用

在 GAN 训练 [4] 的最小-最大游戏中,判别器 D 与生成器 G(您真正关心的网络)进行对抗,因此 D的审查,G在输出真实结果方面变得更好。

为此,D 被训练来区分真实样本和来自 G 的样本;而 G 被训练通过生成符合目标分布的真实结果/结果来愚弄 D

Note: in the case of conditional GANs, i.e. GANs mapping an input sample from one domain A (e.g. real picture) to another domain B (e.g. sketch), D is usually fed with the pairs of samples stacked together and has to discriminate "real" pairs (input sample from A + corresponding target sample from B) and "fake" pairs (input sample from A + corresponding output from G) [1, 2]

针对 D 训练条件生成器(与仅单独训练 G 相反,只有 L1/L2 损失,例如 DAE)提高了 G 的采样能力,迫使它输出清晰、逼真的结果,而不是试图平均分布。

即使鉴别器可以有多个子网络来覆盖其他任务(见下一段),D 应该至少保留一个 sub-network/output 来覆盖其主要任务:区分真实样本和生成样本。要求 D 回归进一步的语义信息(例如 classes)可能会干扰这个主要目的。

Note: D output is often not a simple scalar / boolean. It is common to have a discriminator (e.g. PatchGAN [1, 2]) returning a matrix of probabilities, evaluating how realistic patches made from its input are.


条件 GAN

传统的 GAN 以无监督的方式进行训练,以从随机噪声向量作为输入生成逼真的数据(例如图像)。 [4]

如前所述,条件 GAN 有进一步的输入 条件。 Along/instead 的噪声向量,他们将来自域 A 的样本和 return 来自域 B 的相应样本作为输入。 A 可以是完全不同的方式,例如B = sketch imageA = discrete label ; B = volumetric dataA = RGB image,等等 [3]

这样的 GAN 也可以通过多个输入进行调节,例如A = real image + discrete labelB = sketch image。引入此类方法的著名作品是 InfoGAN [5]。它介绍了如何在多个连续或离散输入(例如 A = digit class + writing typeB = handwritten digit image)上调节 GAN, 使用更高级的鉴别器,该鉴别器具有第二个任务以强制 G最大化其调节输入与其对应输出之间的互信息.


最大化 cGAN 的互信息

InfoGAN 鉴别器有 2 个 heads/sub-networks 来涵盖它的 2 个任务 [5]:

  • 一个头 D1 进行传统的 real/generated 鉴别——G 必须最小化这个结果,即它必须愚弄 D1 这样它就不能区分真实形式生成的数据;
  • 另一个头D2(也称为Q网络)试图回归输入A信息——G必须最大化这个结果,即它必须"show" 请求的语义信息的输出数据(c.f。G 条件输入及其输出之间的互信息最大化)。

例如,您可以在此处找到 Keras 实现:https://github.com/eriklindernoren/Keras-GAN/tree/master/infogan

一些作品正在使用类似的方案来改进对 GAN 生成内容的控制,方法是使用提供的标签并最大化这些输入和 G 输出之间的互信息 [6, 7]。基本思想总是一样的:

  • 训练 G 生成域 B 的元素,给定域 A 的一些输入;
  • 训练 D 以区分 "real"/"fake" 结果 -- G 必须将其最小化;
  • 训练 Q(例如 classifier ;可以与 D 共享层)以估计来自 B 个样本的原始 A 输入 -- G 必须最大化这个)。

总结

在你的例子中,你似乎有以下训练数据:

  • 真实图片Ia
  • 相应的草图图像Ib
  • 对应的class个标签c

并且您想训练一个生成器 G 以便给定一个图像 Ia 及其 class 标签 c,它输出一个正确的草图图像 Ib'.

总而言之,你有很多信息,你可以在条件图像和条件标签上监督你的训练...... 受上述方法 [1, 2, 5, 6, 7] 的启发,这里有一种使用所有这些信息来训练条件 G:

的可能方法 网络G:
  • 输入:Ia + c
  • 输出:Ib'
  • 架构:由您决定(例如 U-Net、ResNet 等)
  • 损失:L1/L2损失在Ib'Ib之间,-D损失,Q损失
网络D:
  • 输入:Ia + Ib(真实对),Ia + Ib'(假对)
  • 输出:"fakeness"scalar/matrix
  • 架构:由您决定(例如 PatchGAN)
  • 损失:"fakeness"估计的交叉熵
网络Q:
  • 输入:Ib(真实样本,用于训练 Q),Ib'(假样本,通过 G 反向传播时)
  • 输出:c'(估计 class)
  • 架构:由您决定(例如 LeNet、ResNet、VGG 等)
  • 损失:cc'之间的交叉熵
训练阶段:
  1. 在一批真实对 Ia + Ib 上训练 D 然后在一批假对 Ia + Ib';
  2. 在一批真实样本上训练 Q Ib
  3. 修复 DQ 权重;
  4. 训练 G,将其生成的输出 Ib' 传递给 DQ 通过它们进行反向传播。

Note: this is a really rough architecture description. I'd recommend going through the literature ([1, 5, 6, 7] as a good start) to get more details and maybe a more elaborate solution.


参考资料

  1. 伊索拉、菲利普等人。 "Image-to-image translation with conditional adversarial networks." arXiv 预印本 (2017)。 http://openaccess.thecvf.com/content_cvpr_2017/papers/Isola_Image-To-Image_Translation_With_CVPR_2017_paper.pdf
  2. 朱俊彦等。 "Unpaired image-to-image translation using cycle-consistent adversarial networks." arXiv 预印本 arXiv:1703.10593 (2017)。 http://openaccess.thecvf.com/content_ICCV_2017/papers/Zhu_Unpaired_Image-To-Image_Translation_ICCV_2017_paper.pdf
  3. 米尔扎、迈赫迪和西蒙·奥辛德罗。 "Conditional generative adversarial nets." arXiv 预印本 arXiv:1411.1784 (2014)。 https://arxiv.org/pdf/1411.1784
  4. Goodfellow、Ian 等人。 "Generative adversarial nets." 神经信息处理系统的进展。 2014. http://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf
  5. 陈曦等。 "Infogan: Interpretable representation learning by information maximizing generative adversarial nets." 神经信息处理系统的进展。 2016. http://papers.nips.cc/paper/6399-infogan-interpretable-representation-learning-by-information-maximizing-generative-adversarial-nets.pdf
  6. Lee、Minhyeok 和 Junhee Seok。 "Controllable Generative Adversarial Network." arXiv 预印本 arXiv:1708.00598 (2017)。 https://arxiv.org/pdf/1708.00598.pdf
  7. Odena、Augustus、Christopher Olah 和 Jonathon Shlens。 "Conditional image synthesis with auxiliary classifier gans." arXiv 预印本 arXiv:1610.09585 (2016)。 http://proceedings.mlr.press/v70/odena17a/odena17a.pdf