如何在数据管道中获取当前 global_step

How to get current global_step in data pipeline

我正在尝试创建一个过滤器,它取决于当前的 global_step 训练,但我没有正确地做到这一点。

首先,我不能在下面的代码中使用 tf.train.get_or_create_global_step(),因为它会抛出

ValueError: Variable global_step already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope? Originally defined at:

这就是为什么我尝试使用 tf.get_default_graph().get_name_scope() 获取范围并在该上下文中我能够“get”全局步骤的原因:

def filter_examples(example):
    scope = tf.get_default_graph().get_name_scope()

    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        current_step = tf.train.get_or_create_global_step()

    subtokens_by_step = tf.floor(current_step / curriculum_step_update)
    max_subtokens = min_subtokens + curriculum_step_size * tf.cast(subtokens_by_step, dtype=tf.int32)

    return tf.size(example['targets']) <= max_subtokens


dataset = dataset.filter(filter_examples)

问题是它似乎没有像我预期的那样工作。根据我的观察,上面代码中的 current_step 似乎一直为 0(我不知道,只是根据我的观察我假设)。

唯一似乎有所作为并且听起来很奇怪的事情是重新开始训练。我认为,同样基于观察,在这种情况下 current_step 将是此时训练的实际当前步骤。但是随着训练的继续,这个值本身不会更新。

是否有办法获取当前步骤的 实际 值并像上面那样在我的过滤器中使用它?


环境

张量流 1.12.1

如果您在数据集中使用变量,则需要在 tf 1.x 中重新初始化迭代器。

iterator = tf.compat.v1.make_initializable_iterator(dataset)
init = iterator.initializer
tensors = iterator.get_next()

with tf.compat.v1.Session() as sess:
    for epoch in range(num_epochs):
        sess.run(init)
        for example in range(num_examples):
            tensor_vals = sess.run(tensors)

正如我们在评论中讨论的那样,拥有并更新您自己的计数器可能是使用 global_step 变量的替代方法。 counter 变量可以更新如下:

op = tf.assign_add(counter, 1)
with tf.control_dependencies(op): 
    # Some operation here before which the counter should be updated

使用 tf.control_dependencies 允许 "attach" 将 counter 更新为计算图中的路径。然后,您可以在任何需要的地方使用 counter 变量。