TF.data.dataset.map(map_func) 急切模式

TF.data.dataset.map(map_func) with Eager Mode

我正在使用启用了急切模式的 TF 1.8。

我无法打印 mapfunc 中的示例。当我从 mapfunc 中 运行 tf.executing_eagerly() 我得到 "False"

import os
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.ERROR)

tfe = tf.contrib.eager
tf.enable_eager_execution()
x = tf.random_uniform([16,10], -10, 0, tf.int64)
print(x)
DS = tf.data.Dataset.from_tensor_slices((x))


def mapfunc(ex, con):
    import pdb; pdb.set_trace()
    new_ex = ex + con
    print(new_ex) 
    return new_ex

DS = DS.map(lambda x: mapfunc(x, [7]))
DS = DS.make_one_shot_iterator()

print(DS.next())

print(new_ex) 输出:

Tensor("add:0", shape=(10,), dtype=int64)

在 mapfunc 之外,它工作正常。但在其中,传递的示例没有值,也没有 .numpy() 属性。

tf.data 转换实际上作为图形执行,因此 map 函数本身的主体不会急切执行。有关此问题的更多讨论,请参阅 #14732

如果你真的需要 map 函数的急切执行,你可以使用 tf.contrib.eager.py_func,比如:

DS = DS.map(lambda x: tf.contrib.eager.py_func(
  mapfunc,
  [x, tf.constant(7, dtype=tf.int64)], tf.int64)
# In TF 1.9+, the next line can be print(next(DS))
print(DS.make_one_shot_iterator().next())

希望对您有所帮助。

请注意,通过向数据集添加 py_func,单线程 Python 解释器将在每个生成的元素的循环中。

地图内的任何内容都是 运行 图形,无论外部使用何种模式。参见 https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map

如页面所示,有3个选项:

  1. Rely on AutoGraph to convert Python code into an equivalent graph computation. The downside of this approach is that AutoGraph can convert some but not all Python code.
  2. Use tf.py_function, which allows you to write arbitrary Python code but will generally result in worse performance than 1)
  3. Use tf.numpy_function, which also allows you to write arbitrary Python code. Note that tf.py_function accepts tf.Tensor whereas tf.numpy_function accepts numpy arrays and returns only numpy arrays.

使用 tf.py_function() 你的行将变成:

DS = DS.map(lambda y: tf.py_function(
                          (lambda x: mapfunc(x, [7])),
                          inp=[y], Tout=tf.int64
                      ))

这同样适用于 tf.map_fn() 和 tf.vectorized_map()。