在估计器模型函数中使用 tf.cond() 在 TPU 上训练 WGAN 导致加倍 global_step

Using tf.cond() in an estimator model function to train a WGAN on a TPU causes doubled global_step

我正在尝试在 TPU 上训练 GAN,所以我一直在摆弄 TPUEstimator class 和附带的模型函数来尝试实现 WGAN 训练循​​环。我正在尝试使用 tf.cond 合并 TPUEstimatorSpec 的两个训练操作:

opt = tf.cond(
    tf.equal(tf.mod(tf.train.get_or_create_global_step(), 
    CRITIC_UPDATES_PER_GEN_UPDATE+1), CRITIC_UPDATES_PER_GEN_UPDATE+1), 
    lambda: gen_opt, 
    lambda: critic_opt
)

gen_optcritic_opt 是我正在使用的优化器的最小化功能,也设置为更新全局步骤。 CRITIC_UPDATES_PER_GEN_UPDATE 是一个 python 常量,是 WGAN 训练的一部分。我试图找到一个使用 tf.cond 的 GAN 模型,但所有模型都使用 tf.group,我不能使用它,因为你需要比生成器优化评论家更多次。 但是我每运行100个batches,global step根据checkpoint number增加200。我的模型是否仍在正确训练,或者 tf.cond 只是不应该以这种方式训练 GAN?

tf.cond 不应该以这种方式用于训练 GAN。

你得到 200,因为每个训练步骤都会评估 true_fnfalse_fn 的副作用(如赋值操作)。副作用之一是两个优化器都定义的全局步骤 tf.assign_add 操作。

因此,发生的事情就像

  • global_step++ (gen_opt)global_step++ (critic_op)
  • 的执行
  • 条件评估
  • 执行 true_fn 正文或 false_fn 正文(取决于条件)。

如果你想使用 tf.cond 训练 GAN,你必须从 true_fn/ false_fn 并在其中声明所有内容。

作为参考,您可以查看有关 tf.cond 行为的答案: