如何在 GANs 训练中使用 TensorFlow Dataset API?
How to use TensorFlow Dataset API in GANs training?
我正在训练 GAN 模型。为了加载数据集,我使用的是 TensorFlow 的数据集 API。
# train_dataset has image and label. z_train dataset has noise (z).
train_dataset = tf.data.TFRecordDataset(train_file)
z_train = tf.data.Dataset.from_tensor_slices(tf.random_uniform([total_training_samples, seq_length, z_dim],
minval=0, maxval=1, dtype=tf.float32))
train_dataset = tf.data.Dataset.zip((train_dataset, z_train))
正在创建迭代器:
iter = tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)
使用迭代器:
(img, label), z = iter.get_next()
train_init_op = iter.make_initializer(train_dataset)
在会话中训练 GAN 时:
首先训练判别器:
_, disc_loss = sess.run([disc_optim, disc_loss])
然后训练生成器:
_, gen_loss = sess.run([gen_optim, gen_loss])
问题来了。因为,我在判别器和生成器图中都使用 label 作为条件(CGAN),使用两个 sess.run 产生两组不同的 label 在同一个 运行 批次中。
for epoch in range(num_of_epochs):
sess.run([tf.global_variables_initializer(), train_init_op.initializer])
for batch in range(num_of_batches):
_, disc_loss = sess.run([disc_optim, disc_loss])
_, gen_loss = sess.run([gen_optim, gen_loss])
因为,我必须在生成器会话 运行 和鉴别器会话 运行 中提供同一批 label,我该如何防止数据集API 在一个批次的同一个循环中生产两个不同的批次?
注意:我使用的是 TensorFlow v1.9
提前致谢。
您可以为同一个数据集创建 2 个迭代器。如果您需要打乱数据集,您甚至可以通过将种子指定为张量来实现。请参阅下面的示例。
import tensorflow as tf
seed_ts = tf.placeholder(tf.int64)
ds = tf.data.Dataset.from_tensor_slices([1,2,3,4,5]).shuffle(5, seed=seed_ts, reshuffle_each_iteration=True)
it1 = ds.make_initializable_iterator()
it2 = ds.make_initializable_iterator()
input1 = it1.get_next()
input2 = it2.get_next()
with tf.Session() as sess:
for ep in range(10):
sess.run(it1.initializer, feed_dict={seed_ts: ep})
sess.run(it2.initializer, feed_dict={seed_ts: ep})
print("Epoch" + str(ep))
for i in range(5):
x = sess.run(input1)
y = sess.run(input2)
print([x, y])
我正在训练 GAN 模型。为了加载数据集,我使用的是 TensorFlow 的数据集 API。
# train_dataset has image and label. z_train dataset has noise (z).
train_dataset = tf.data.TFRecordDataset(train_file)
z_train = tf.data.Dataset.from_tensor_slices(tf.random_uniform([total_training_samples, seq_length, z_dim],
minval=0, maxval=1, dtype=tf.float32))
train_dataset = tf.data.Dataset.zip((train_dataset, z_train))
正在创建迭代器:
iter = tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)
使用迭代器:
(img, label), z = iter.get_next()
train_init_op = iter.make_initializer(train_dataset)
在会话中训练 GAN 时:
首先训练判别器:
_, disc_loss = sess.run([disc_optim, disc_loss])
然后训练生成器:
_, gen_loss = sess.run([gen_optim, gen_loss])
问题来了。因为,我在判别器和生成器图中都使用 label 作为条件(CGAN),使用两个 sess.run 产生两组不同的 label 在同一个 运行 批次中。
for epoch in range(num_of_epochs):
sess.run([tf.global_variables_initializer(), train_init_op.initializer])
for batch in range(num_of_batches):
_, disc_loss = sess.run([disc_optim, disc_loss])
_, gen_loss = sess.run([gen_optim, gen_loss])
因为,我必须在生成器会话 运行 和鉴别器会话 运行 中提供同一批 label,我该如何防止数据集API 在一个批次的同一个循环中生产两个不同的批次?
注意:我使用的是 TensorFlow v1.9
提前致谢。
您可以为同一个数据集创建 2 个迭代器。如果您需要打乱数据集,您甚至可以通过将种子指定为张量来实现。请参阅下面的示例。
import tensorflow as tf
seed_ts = tf.placeholder(tf.int64)
ds = tf.data.Dataset.from_tensor_slices([1,2,3,4,5]).shuffle(5, seed=seed_ts, reshuffle_each_iteration=True)
it1 = ds.make_initializable_iterator()
it2 = ds.make_initializable_iterator()
input1 = it1.get_next()
input2 = it2.get_next()
with tf.Session() as sess:
for ep in range(10):
sess.run(it1.initializer, feed_dict={seed_ts: ep})
sess.run(it2.initializer, feed_dict={seed_ts: ep})
print("Epoch" + str(ep))
for i in range(5):
x = sess.run(input1)
y = sess.run(input2)
print([x, y])