数据集的 from_generator 仅在急切模式下以 "Only integers are valid indices" 失败,在图形模式下没有错误

Dataset's from_generator fails with a "Only integers are valid indices" in eager mode only, no error in graph mode

我正在使用 from_generator 函数创建张量流 Dataset。在 graph/session 模式下,它工作正常:

import tensorflow as tf

x = {str(i): i for i in range(10)}

def gen():
  for i in x:
    yield x[i]

ds = tf.data.Dataset.from_generator(gen, tf.int32)
batch = ds.make_one_shot_iterator().get_next()

with tf.Session() as sess:
  while True:
    try:
      print(sess.run(batch), end=' ')
    except tf.errors.OutOfRangeError:
      break
# 0 1 2 3 4 5 6 7 8 9

然而,令人惊讶的是,它使用急切执行失败了:

import tensorflow as tf
tf.enable_eager_execution()

x = {str(i): i for i in range(10)}

def gen():
  for i in x:
    yield x[i]

ds = tf.data.Dataset.from_generator(gen, tf.int32)

for x in ds:
  print(x, end=' ')
# TypeError: Only integers, slices (`:`), ellipsis (`...`), tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid indices, got '1'

我假设,由于生成器的主体是纯粹的 python,不会被序列化,所以 tensorflow 不会查看——实际上不关心生成器中有什么。但事实显然并非如此。那么为什么 tensorflow 关心生成器里面有什么?假设无法更改生成器,是否有办法以某种方式解决此问题?

tl;dr 该问题与 TensorFlow 无关。您的循环变量阴影先前定义 x.

事实 1:Python 中的 for 循环没有名称空间,并将循环变量泄漏到周围的名称空间中(在您的示例中为 globals())。

事实 2:闭包是 "dynamic" 即 gen 生成器只知道它应该查找名称 "x" 来计算 x[i]x 的实际值将在迭代生成器时解决。

将这两个放在一起并展开 for 循环的前两次迭代,我们得到以下执行顺序:

ds = tf.data.Dataset.from_generator(gen, tf.int32)
it = iter(ds)
x = next(it)  # Calls to the generator which yields back x[i].
print(x, end='')
# Calls to the generator as before, but x is no longer a dict so x[i]
# is actually indexing into a Tensor!
x = next(it)  

修复很简单:使用不同的循环变量名称。

for item in ds:
  print(item, end=' ')