使用 tf.train.Checkpoint 在 keras 中保存 GAN

Saving a GAN in keras using tf.train.Checkpoint

更新:为了解决这个问题,我保持检查点结构不变,但编写了一个自定义 train_step 函数,在回购的帮助下 linked 在接受的问题答案中 linked 下面,计算梯度并使用 apply_weights 而不是编译模型并使用 train_on_batch。这样可以恢复完整的 GAN 状态。遗憾的是,通过这种方法,我相当确定 dropout 层不再起作用,因为鉴别器能够在训练的早期就完美地工作,这会阻止模型正确训练。不过,原来的问题还是解决了。

原文:

我目前正在 keras 中训练一个 GAN 并尝试制作它,以便我可以保存模型并稍后恢复训练。通常在 keras 中,您只需使用 model.save(),但是对于 GAN,如果鉴别器和 GAN(组合的生成器和鉴别器,鉴别器权重不可训练)模型分别保存和加载,那么 link它们之间的连接被破坏,GAN 将无法按预期运行。有人在这里问了一个类似的问题,,并被告知使用 tf.train.Checkpoint 来立即保存完整模型作为检查点。

我试过如下实现:

def train(epochs, batch_size):
    checkpoint = tf.train.Checkpoint(g_optimizer=g_optimizer,
                                     d_optimizer=d_optimizer,
                                     generator=generator,
                                     discriminator=discriminator,
                                     gan=gan
                                     )
    ckpt_manager = tf.train.CheckpointManager(checkpoint, 'checkpoints', max_to_keep=3)

    if ckpt_manager.latest_checkpoint:
        checkpoint.restore(ckpt_manager.latest_checkpoint)
        discriminator.compile(loss='binary_crossentropy', optimizer=d_optimizer)

        i = Input(shape=(None, latent_dims))
        lcs = generator(i)

        discriminator.trainable = False

        valid = discriminator(lcs)

        gan = Model(i, valid)
        gan.compile(loss='binary_crossentropy', optimizer=g_optimizer)

    for epoch in epochs:
        #train discriminator...
        #train generator...
        ckpt_manager.save()

其中 g_optimizer、d_optimizer 只是 tf.keras.optimizers.Adam 个对象,生成器、鉴别器和 gan 是 tf.keras.Model 个对象。

当我使用这种方法时,在检查点加载后,gan 模型和鉴别器之间的 link 被保留。训练一开始工作正常,但在我停止然后使用检查点恢复训练后,鉴别器损失开始大量增加,生成的数据变得毫无意义。

重新编译模型正在加载检查点,这是我唯一能想到的使用优化器最后状态的方法,但显然有些地方不对——而不是从原来的地方恢复训练,这种方法极大地扰乱了训练。

我是否错误地使用了 tf.train.Checkpoint 来完成我想做的事情?如果您需要更多信息来解决问题,请告诉我。

编辑,已按要求添加完整代码:

这是首先创建模型然后训练它们的代码,在此设置中,模型在首次创建时进行初始编译,如果使用最新的优化器状态从检查点恢复,则再次编译。我明白编译两次很奇怪,但我想不出另一种方法来使用检查点的最新优化器状态,​​如果有更好的方法我很乐意改变它。请注意,不寻常的基于 GRU 的 GAN 是因为我正在测试能够生成可变长度的时间序列。那里有很多特定于数据的东西,但希望总的来说它是有意义的。 train_df 只是一个包含所有训练数据的 pandas DataFrame

def build_generator():
    input = Input(shape=(None, latent_dims))
    gru1 = GRU(100, activation='relu', return_sequences=True)(input)
    gru2 = GRU(100, activation='relu', return_sequences=True (gru1)
    output = GRU(9, return_sequences=True, activation='sigmoid')(gru2)
    model = Model(input, output)
    return model

def build_discriminator():
    input = Input(shape=(None, 9))
    gru1 = GRU(100, return_sequences=True)(input)
    gru2 = GRU(100, return_sequences=True)(gru1)
    output = GRU(1, activation='sigmoid')(gru2)
    model = Model(input, output)
    return model

d_optimizer = opt.Adam(learning_rate=lr)
g_optimizer = opt.Adam(learning_rate=lr)

# Build discriminator
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer=d_optimizer)

# Build generator
generator = build_generator()

# Build combined model
i = Input(shape=(None, latent_dims))
lcs = generator(i)
discriminator.trainable = False
valid = discriminator(lcs)

gan = Model(i, valid)
gan.compile(loss='binary_crossentropy', optimizer=g_optimizer)

def train(epochs, batch_size=1): #Only works with batch size of 1 currently
    sne = train_df.sn.unique()
    n_batches = int(len(sne) / batch_size)
    rng = np.random.default_rng(123)

    checkpoint = tf.train.Checkpoint(g_optimizer=g_optimizer,
                                     d_optimizer=d_optimizer,
                                     generator=generator,
                                     discriminator=discriminator,
                                     gan=gan
                                     )
    ckpt_manager = tf.train.CheckpointManager(checkpoint, 'checkpoints', max_to_keep=3)
    if ckpt_manager.latest_checkpoint:
        checkpoint.restore(ckpt_manager.latest_checkpoint)
        discriminator.compile(loss='binary_crossentropy', optimizer=d_optimizer)

        i = Input(shape=(None, latent_dims))
        lcs = generator(i)

        discriminator.trainable = False
        valid = discriminator(lcs)

        gan = Model(i, valid)
        gan.compile(loss='binary_crossentropy', optimizer=g_optimizer)

    for epoch in range(epochs):
        rng.shuffle(sne)
        g_losses, d_losses = [], []
        for batch in range(n_batches):
            real = np.random.uniform(0.0, 0.1, (batch_size, 1)) # Used instead of np.zeros to avoid zero gradients
            fake = np.random.uniform(0.9, 1.0, (batch_size, 1)) # Used instead of np.ones to avoid zero gradients

        # Select real data
        sn = sne[batch]
        sndf = train_df[train_df.sn == sn]
        X = sndf[['g_t', 'r_t', 'i_t', 'z_t', 'g', 'r', 'i', 'z', 'g_err', 'r_err', 'i_err', 'z_err']].values

        X = X.reshape((1, *X.shape))

        noise = rand.normal(size=(batch_size, latent_dims))
        noise = np.reshape(noise, (batch_size, 1, latent_dims))
        noise = np.repeat(noise, X.shape[1], 1)

        gen_lcs = generator.predict(noise)

        # Train discriminator
        d_loss_real = discriminator.train_on_batch(X, real)
        d_loss_fake = discriminator.train_on_batch(gen_lcs, fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
        
        # Train generator
        noise = rand.normal(size=(2 * batch_size, latent_dims))
        noise = np.reshape(noise, (2 * batch_size, 1, latent_dims))
        noise = np.repeat(noise, X.shape[1], 1)

        gen_labels = np.zeros((2 * batch_size, 1))
        g_loss = gan.train_on_batch(noise, gen_labels)
        g_losses.append(g_loss)
        d_losses.append(d_loss)
    ckpt_manager.save()
    full_g_loss = np.mean(g_losses)
    full_d_loss = np.mean(d_losses)
    print(f'{epoch + 1}/{epochs} g_loss={full_g_loss}, d_loss={full_d_loss})

train()

如果您有以下检查点结构,您的模型应该可以正常工作:

checkpoint_dir = 'checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_opt=generator_opt,
                                  discriminator_opt=discriminator_opt,
                                  gan_opt=gan_opt,
                                  generator=generator,
                                  discriminator=discriminator,
                                  GAN = GAN
                                  )

ckpt_manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)
if ckpt_manager.latest_checkpoint:
  checkpoint.restore(ckpt_manager.latest_checkpoint)
  print ('Latest checkpoint restored!!')

请注意,GAN 模型有自己的优化器。然后在你的训练循环中,每隔一定时间保存检查点,例如每 10 个时期。

for epoch in range(epochs):
...
...
...
  if epoch%10 == 0:
    ckpt_manager.save()