Returning strings in tf.data.Dataset map method
In TensorFlow 1.4.1, the map method of tf.data.Dataset could return strings, so my map function could return something like this:
return filename, image, one_hot_label
where filename is a string. This no longer works in TF 1.5+:
dataset = dataset.map(self._mapper)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 838, in map
return MapDataset(self, map_func)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 1826, in __init__
self._map_func.add_to_graph(ops.get_default_graph())
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/function.py", line 488, in add_to_graph
self._create_definition_if_needed()
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/function.py", line 321, in _create_definition_if_needed
self._create_definition_if_needed_impl()
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/function.py", line 338, in _create_definition_if_needed_impl
outputs = self._func(*inputs)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 1814, in tf_map_func
ret, [t.get_shape() for t in nest.flatten(ret)])
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 1814, in <listcomp>
ret, [t.get_shape() for t in nest.flatten(ret)])
AttributeError: 'str' object has no attribute 'get_shape'
Is this by design or a regression?
A reproducible example:
import tensorflow as tf
def map_fn(x):
    return x*2, 'foo'
dataset = tf.data.Dataset.range(5)
dataset = dataset.map(map_fn)
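The last frames of the traceback show the map wrapper flattening the returned structure with `nest.flatten` and calling `get_shape()` on every element, a method a plain Python `str` does not have. The sketch below reproduces that check outside `Dataset.map`. It assumes a TensorFlow version that exposes `tf.nest` (1.13+ or 2.x), and `has_get_shape` is a hypothetical helper, not part of TensorFlow.

```python
import tensorflow as tf

def has_get_shape(ret):
    # Hypothetical helper mirroring the comprehension in the traceback:
    # flatten the returned structure, then check each leaf for get_shape().
    return [hasattr(t, 'get_shape') for t in tf.nest.flatten(ret)]

x = tf.constant(2, dtype=tf.int64)

# What map_fn returns in the reproducible example: the str leaf has no get_shape.
print(has_get_shape((x * 2, 'foo')))

# Converting the string first yields a scalar tf.string tensor, which does.
foo = tf.convert_to_tensor('foo')
print(has_get_shape((x * 2, foo)))
print(foo.dtype, foo.get_shape())
```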
As discussed in the comments, this appears to be a bug affecting TF 1.5 and at least 1.6, possibly 1.7 as well. I have opened a GitHub issue at https://github.com/tensorflow/tensorflow/issues/18355
Until the issue is fixed in a future TensorFlow release, I recommend explicitly converting string outputs to tensors:
import tensorflow as tf
def map_fn(x):
    # Explicitly convert 'foo' to a tensor
    return x*2, tf.convert_to_tensor('foo')
dataset = tf.data.Dataset.range(5)
dataset = dataset.map(map_fn)
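Applied to the three-tuple from the question, the same workaround might look like the sketch below. Apart from the `tf.convert_to_tensor` call, everything here is assumed for illustration: `mapper`, `NUM_CLASSES`, the constant filename and the zero-filled placeholder image are hypothetical, and the loop at the end assumes TF 2.x eager execution so the dataset can be iterated directly (under TF 1.x graph mode you would go through an iterator and a session instead).

```python
import tensorflow as tf

NUM_CLASSES = 3  # hypothetical number of classes

def mapper(label):
    # Hypothetical filename: a plain Python str produced inside the mapper.
    filename = 'image.png'
    # Hypothetical placeholder image; a real mapper would read and decode a file.
    image = tf.zeros([2, 2, 3], dtype=tf.float32)
    one_hot_label = tf.one_hot(label, NUM_CLASSES)
    # Workaround: convert the str so every returned component is a tf.Tensor.
    return tf.convert_to_tensor(filename), image, one_hot_label

labels = tf.constant([0, 2, 1], dtype=tf.int64)
dataset = tf.data.Dataset.from_tensor_slices(labels).map(mapper)

# Eager iteration (TF 2.x assumption): each element is a 3-tuple of tensors.
for filename, image, one_hot_label in dataset:
    print(filename.numpy(), image.shape, one_hot_label.numpy())
```

Converting inside the mapper leaves the rest of the pipeline untouched; downstream code simply receives the filename as a scalar `tf.string` tensor.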