在估计器模型函数中使用 tf.cond() 在 TPU 上训练 WGAN 导致加倍 global_step
Using tf.cond() in an estimator model function to train a WGAN on a TPU causes doubled global_step
我正在尝试在 TPU 上训练 GAN,所以我一直在摆弄 TPUEstimator class 和附带的模型函数来尝试实现 WGAN 训练循环。我正在尝试使用 tf.cond
合并 TPUEstimatorSpec 的两个训练操作:
opt = tf.cond(
tf.equal(tf.mod(tf.train.get_or_create_global_step(),
CRITIC_UPDATES_PER_GEN_UPDATE+1), CRITIC_UPDATES_PER_GEN_UPDATE+1),
lambda: gen_opt,
lambda: critic_opt
)
gen_opt
和 critic_opt
是我正在使用的优化器的最小化功能,也设置为更新全局步骤。 CRITIC_UPDATES_PER_GEN_UPDATE
是一个 python 常量,是 WGAN 训练的一部分。我试图找到一个使用 tf.cond
的 GAN 模型,但所有模型都使用 tf.group
,我不能使用它,因为你需要比生成器优化评论家更多次。
但是我每运行100个batches,global step根据checkpoint number增加200。我的模型是否仍在正确训练,或者 tf.cond
只是不应该以这种方式训练 GAN?
tf.cond
不应该以这种方式用于训练 GAN。
你得到 200,因为每个训练步骤都会评估 true_fn
和 false_fn
的副作用(如赋值操作)。副作用之一是两个优化器都定义的全局步骤 tf.assign_add
操作。
因此,发生的事情就像
global_step++ (gen_opt)
和 global_step++ (critic_op)
的执行
- 条件评估
- 执行
true_fn
正文或 false_fn
正文(取决于条件)。
如果你想使用 tf.cond
训练 GAN,你必须从 true_fn
/ false_fn
并在其中声明所有内容。
作为参考,您可以查看有关 tf.cond
行为的答案:
我正在尝试在 TPU 上训练 GAN,所以我一直在摆弄 TPUEstimator class 和附带的模型函数来尝试实现 WGAN 训练循环。我正在尝试使用 tf.cond
合并 TPUEstimatorSpec 的两个训练操作:
opt = tf.cond(
tf.equal(tf.mod(tf.train.get_or_create_global_step(),
CRITIC_UPDATES_PER_GEN_UPDATE+1), CRITIC_UPDATES_PER_GEN_UPDATE+1),
lambda: gen_opt,
lambda: critic_opt
)
gen_opt
和 critic_opt
是我正在使用的优化器的最小化功能,也设置为更新全局步骤。 CRITIC_UPDATES_PER_GEN_UPDATE
是一个 python 常量,是 WGAN 训练的一部分。我试图找到一个使用 tf.cond
的 GAN 模型,但所有模型都使用 tf.group
,我不能使用它,因为你需要比生成器优化评论家更多次。
但是我每运行100个batches,global step根据checkpoint number增加200。我的模型是否仍在正确训练,或者 tf.cond
只是不应该以这种方式训练 GAN?
tf.cond
不应该以这种方式用于训练 GAN。
你得到 200,因为每个训练步骤都会评估 true_fn
和 false_fn
的副作用(如赋值操作)。副作用之一是两个优化器都定义的全局步骤 tf.assign_add
操作。
因此,发生的事情就像
global_step++ (gen_opt)
和global_step++ (critic_op)
的执行
- 条件评估
- 执行
true_fn
正文或false_fn
正文(取决于条件)。
如果你想使用 tf.cond
训练 GAN,你必须从 true_fn
/ false_fn
并在其中声明所有内容。
作为参考,您可以查看有关 tf.cond
行为的答案: