运行 使用更多示例训练鉴别器

Running training the discriminator with more examples

据我了解,常规 GAN 与 WGAN 之间的区别在于我们在每个时期用更多示例训练 discriminator/critic。如果在常规 gan 中,我们在每个 epoch 中为两个模块分配了一批,那么在 WGAN 中,我们将有 5 个(或更多)批次用于鉴别器,一个用于生成器。

所以基本上我们有另一个鉴别器的内部循环:

real_images_labels = np.ones((BATCH_SIZE, 1))
 fake_images_labels = -real_images_labels
 for epoch in range(epochs):
    for batch in range(NUM_BACHES):
        for critic_iter in range(n_critic):
        random_batches_idx = np.random.randint(0, NUM_BACHES) # Choose random batch from dataset
        imgs_data=dataset_list[random_batches_idx]
        c_loss_real = critic.train_on_batch(imgs_data, real_images_labels) # update the weights after 1 batch

        noise = tf.random.normal([imgs_data.shape[0], noise_dim]) # Generate noise data
        generated_images = generator(noise, training=True)
        c_loss_fake = critic.train_on_batch(generated_images, fake_images_labels)  # update the weights after 1 batch
      
    
      imgs_data=dataset_list[batch]
      noise = tf.random.normal([imgs_data.shape[0], noise_dim]) # Generate noise data
      gen_loss_batch = gen_loss_batch + gan.train_on_batch(noise,real_images_labels)

训练花了我很多时间,每个 epoch 大约 3m。我不得不减少训练时间的想法是 运行 将每个批次向前 n_critic 次我可以增加鉴别器的 batch_size 和 运行 向前一次更大 batch_size.

我正在寻求反馈:听起来合理吗?

(我没有粘贴我的全部代码,它只是其中的一部分)。

是的,这听起来很合理,通常 在训练期间增加 batch_size,通常会以使用 更多内存为代价减少训练时间精度较低(泛化能力较低)

话虽如此,您应该始终对批处理进行反复试验,因为极值可能会也可能不会增加训练时间。

进一步的讨论可以参考这个question