Saving a GAN in Keras using tf.train.Checkpoint
Update: To solve this, I kept the checkpoint structure the same but wrote a custom train_step function, with the help of the repo linked in the accepted answer below, which computes the gradients and applies them with optimizer.apply_gradients instead of compiling the models and using train_on_batch. This allows the full GAN state to be restored. Unfortunately, with this approach I'm fairly sure the dropout layers no longer have any effect, since the discriminator becomes able to classify perfectly very early in training, which prevents the model from training properly. Still, the original problem is solved.
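A minimal sketch of that custom train_step approach, assuming the question's label convention (real ≈ 0, fake ≈ 1, with the generator trying to get fakes labelled 0). The tiny dense models and latent_dims value here are illustrative stand-ins, not the question's actual GRU architecture:

```python
import tensorflow as tf

latent_dims = 8  # assumed value for illustration

# Stand-in models; the real code would use the GRU-based models below
generator = tf.keras.Sequential(
    [tf.keras.Input((latent_dims,)), tf.keras.layers.Dense(4)])
discriminator = tf.keras.Sequential(
    [tf.keras.Input((4,)), tf.keras.layers.Dense(1, activation='sigmoid')])
g_optimizer = tf.keras.optimizers.Adam(1e-4)
d_optimizer = tf.keras.optimizers.Adam(1e-4)
bce = tf.keras.losses.BinaryCrossentropy()

@tf.function
def train_step(real_batch, batch_size):
    noise = tf.random.normal((batch_size, latent_dims))
    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
        fake_batch = generator(noise, training=True)
        real_pred = discriminator(real_batch, training=True)
        fake_pred = discriminator(fake_batch, training=True)
        # Discriminator: push real predictions toward 0, fakes toward 1
        d_loss = (bce(tf.zeros_like(real_pred), real_pred)
                  + bce(tf.ones_like(fake_pred), fake_pred))
        # Generator: fool the discriminator into predicting 0 for fakes
        g_loss = bce(tf.zeros_like(fake_pred), fake_pred)
    # Compute and apply gradients directly; no compiled models needed,
    # so tf.train.Checkpoint captures the complete training state
    d_grads = d_tape.gradient(d_loss, discriminator.trainable_variables)
    g_grads = g_tape.gradient(g_loss, generator.trainable_variables)
    d_optimizer.apply_gradients(zip(d_grads, discriminator.trainable_variables))
    g_optimizer.apply_gradients(zip(g_grads, generator.trainable_variables))
    return g_loss, d_loss
```

Because the optimizers are only ever used through apply_gradients, the compile/recompile dance from the original question disappears entirely.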
Original question:
I'm currently training a GAN in Keras and trying to set it up so that I can save the models and resume training later. Normally in Keras you would just use model.save(), but for a GAN, if the discriminator and the GAN (the combined generator and discriminator, with the discriminator's weights set to non-trainable) are saved and loaded separately, the link between them is broken and the GAN no longer behaves as expected. Someone asked a similar question here and was advised to use tf.train.Checkpoint to save the complete model in one go as a checkpoint.
I tried implementing it as follows:
def train(epochs, batch_size):
    checkpoint = tf.train.Checkpoint(g_optimizer=g_optimizer,
                                     d_optimizer=d_optimizer,
                                     generator=generator,
                                     discriminator=discriminator,
                                     gan=gan)
    ckpt_manager = tf.train.CheckpointManager(checkpoint, 'checkpoints', max_to_keep=3)
    if ckpt_manager.latest_checkpoint:
        checkpoint.restore(ckpt_manager.latest_checkpoint)
        discriminator.compile(loss='binary_crossentropy', optimizer=d_optimizer)
        i = Input(shape=(None, latent_dims))
        lcs = generator(i)
        discriminator.trainable = False
        valid = discriminator(lcs)
        gan = Model(i, valid)
        gan.compile(loss='binary_crossentropy', optimizer=g_optimizer)
    for epoch in range(epochs):
        # train discriminator...
        # train generator...
        ckpt_manager.save()
where g_optimizer and d_optimizer are just tf.keras.optimizers.Adam objects, and generator, discriminator, and gan are tf.keras.Model objects.
With this approach, the link between the gan model and the discriminator is preserved after the checkpoint is loaded. Training works fine at first, but after I stop and then resume training from the checkpoint, the discriminator loss starts increasing dramatically and the generated data becomes nonsensical.
Recompiling the models when loading the checkpoint is the only way I could think of to use the optimizers' last state, but clearly something is wrong: rather than resuming training where it left off, this approach severely disrupts training.
Am I using tf.train.Checkpoint incorrectly for what I'm trying to do? Please let me know if you need any more information to figure out the problem.
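For reference, a minimal round-trip sketch (illustrative toy model, not the question's GAN) of the behaviour tf.train.Checkpoint is expected to provide: after one training step, saving and then restoring brings back both the model weights and the optimizer state, with no recompilation involved:

```python
import tempfile

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.Input((3,)), tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adam(0.01)
ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
manager = tf.train.CheckpointManager(ckpt, tempfile.mkdtemp(), max_to_keep=1)

# Take one training step so the optimizer has slot variables worth saving
with tf.GradientTape() as tape:
    loss = tf.reduce_sum(model(tf.ones((1, 3))) ** 2)
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))

manager.save()
saved_kernel = model.trainable_variables[0].numpy().copy()

# Perturb the weights, then restore: the weights come back, and so do
# the optimizer's iteration count and moment estimates
model.trainable_variables[0].assign_add(
    tf.ones_like(model.trainable_variables[0]))
ckpt.restore(manager.latest_checkpoint)
```

If the restored run still diverges, the problem is more likely the recompilation step resetting the graph connections than the checkpoint itself.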
Edit: full code added as requested:
Here is the code that first creates the models and then trains them. In this setup the models are compiled initially when first created, then compiled again if restoring from a checkpoint so that the latest optimizer state is used. I realize compiling twice is odd, but I couldn't think of another way to use the checkpoint's latest optimizer state; if there's a better way I'd be happy to change it. Note that the unusual GRU-based GAN is because I'm experimenting with generating variable-length time series. There is a lot of data-specific stuff in there, but hopefully it makes sense overall. train_df is just a pandas DataFrame containing all of the training data.
# Imports assumed by this code (not shown in the original post)
import numpy as np
import tensorflow as tf
from tensorflow.keras import optimizers as opt
from tensorflow.keras.layers import Input, GRU
from tensorflow.keras.models import Model

def build_generator():
    input = Input(shape=(None, latent_dims))
    gru1 = GRU(100, activation='relu', return_sequences=True)(input)
    gru2 = GRU(100, activation='relu', return_sequences=True)(gru1)
    output = GRU(9, return_sequences=True, activation='sigmoid')(gru2)
    model = Model(input, output)
    return model

def build_discriminator():
    input = Input(shape=(None, 9))
    gru1 = GRU(100, return_sequences=True)(input)
    gru2 = GRU(100, return_sequences=True)(gru1)
    output = GRU(1, activation='sigmoid')(gru2)
    model = Model(input, output)
    return model
d_optimizer = opt.Adam(learning_rate=lr)
g_optimizer = opt.Adam(learning_rate=lr)
# Build discriminator
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer=d_optimizer)
# Build generator
generator = build_generator()
# Build combined model
i = Input(shape=(None, latent_dims))
lcs = generator(i)
discriminator.trainable = False
valid = discriminator(lcs)
gan = Model(i, valid)
gan.compile(loss='binary_crossentropy', optimizer=g_optimizer)
def train(epochs, batch_size=1):  # Only works with batch size of 1 currently
    sne = train_df.sn.unique()
    n_batches = int(len(sne) / batch_size)
    rng = np.random.default_rng(123)
    checkpoint = tf.train.Checkpoint(g_optimizer=g_optimizer,
                                     d_optimizer=d_optimizer,
                                     generator=generator,
                                     discriminator=discriminator,
                                     gan=gan)
    ckpt_manager = tf.train.CheckpointManager(checkpoint, 'checkpoints', max_to_keep=3)
    if ckpt_manager.latest_checkpoint:
        checkpoint.restore(ckpt_manager.latest_checkpoint)
        discriminator.compile(loss='binary_crossentropy', optimizer=d_optimizer)
        i = Input(shape=(None, latent_dims))
        lcs = generator(i)
        discriminator.trainable = False
        valid = discriminator(lcs)
        gan = Model(i, valid)
        gan.compile(loss='binary_crossentropy', optimizer=g_optimizer)
    for epoch in range(epochs):
        rng.shuffle(sne)
        g_losses, d_losses = [], []
        for batch in range(n_batches):
            real = np.random.uniform(0.0, 0.1, (batch_size, 1))  # Used instead of np.zeros to avoid zero gradients
            fake = np.random.uniform(0.9, 1.0, (batch_size, 1))  # Used instead of np.ones to avoid zero gradients
            # Select real data
            sn = sne[batch]
            sndf = train_df[train_df.sn == sn]
            X = sndf[['g_t', 'r_t', 'i_t', 'z_t', 'g', 'r', 'i', 'z', 'g_err', 'r_err', 'i_err', 'z_err']].values
            X = X.reshape((1, *X.shape))
            noise = rng.normal(size=(batch_size, latent_dims))
            noise = np.reshape(noise, (batch_size, 1, latent_dims))
            noise = np.repeat(noise, X.shape[1], 1)
            gen_lcs = generator.predict(noise)
            # Train discriminator
            d_loss_real = discriminator.train_on_batch(X, real)
            d_loss_fake = discriminator.train_on_batch(gen_lcs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
            # Train generator
            noise = rng.normal(size=(2 * batch_size, latent_dims))
            noise = np.reshape(noise, (2 * batch_size, 1, latent_dims))
            noise = np.repeat(noise, X.shape[1], 1)
            gen_labels = np.zeros((2 * batch_size, 1))
            g_loss = gan.train_on_batch(noise, gen_labels)
            g_losses.append(g_loss)
            d_losses.append(d_loss)
        ckpt_manager.save()
        full_g_loss = np.mean(g_losses)
        full_d_loss = np.mean(d_losses)
        print(f'{epoch + 1}/{epochs} g_loss={full_g_loss}, d_loss={full_d_loss}')

train()
Your model should work fine if you use the following checkpoint structure:
checkpoint_dir = 'checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_opt=generator_opt,
                                 discriminator_opt=discriminator_opt,
                                 gan_opt=gan_opt,
                                 generator=generator,
                                 discriminator=discriminator,
                                 GAN=GAN)
ckpt_manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)
if ckpt_manager.latest_checkpoint:
    checkpoint.restore(ckpt_manager.latest_checkpoint)
    print('Latest checkpoint restored!!')
Note that the GAN model has its own optimizer. Then, in your training loop, save a checkpoint at regular intervals, e.g. every 10 epochs.
for epoch in range(epochs):
    ...
    if epoch % 10 == 0:
        ckpt_manager.save()