py_func 无法处理包含超过 8 个元素的列表

py_func is unable to handle a list with more than 8 items

我正在尝试使用 Dataset.map 向一个函数传入两个值,该函数返回一个值列表。当返回的列表包含超过八个元素时,py_func 无法处理该返回值的类型;而当 map_func 返回的列表恰好包含八个元素时,一切正常。

TensorFlow 版本为 1.4.1,运行在 Trisquel 发行版上。

成功案例

import tensorflow as tf

def gen_range(groups=4, limit=1000):
    """Yield ``(start, stop)`` bounds that split ``[0, limit)`` into ``groups`` chunks.

    Consecutive chunks share a boundary: each ``stop`` is the next ``start``,
    and the final ``stop`` is exactly ``limit``. Bounds are computed with
    integer floor division. That keeps them exact ints, matching the
    ``tf.int32`` output types declared for this generator. It also guarantees
    termination even when ``limit`` is not divisible by ``groups``. The
    previous float-accumulation loop could step past ``limit`` and never stop,
    e.g. with ``groups=10, limit=1``. When ``limit < groups`` some chunks are
    empty (``start == stop``).

    Args:
        groups: Number of chunks to produce; must be at least 1.
        limit: Integer end of the range to split. When 0, nothing is yielded.

    Yields:
        ``(start, stop)`` tuples of ints.

    Raises:
        ValueError: If ``groups`` is less than 1.
    """
    if groups < 1:
        raise ValueError("groups must be at least 1")
    if limit == 0:
        # Matches the original behaviour: an empty range yields no chunks.
        return
    start = 0
    for i in range(1, groups + 1):
        # Exact integer boundary; for i == groups this is exactly ``limit``.
        stop = limit * i // groups
        yield start, stop
        start = stop

def bridge(x, y):
    """Return ``x`` and ``y`` each repeated four times, wrapped in one outer list.

    Both inner rows have the same length (4). The outer single-element list
    lines up with the single ``tf.int32`` entry in ``Tout``. Its 2x4 content
    becomes that output, printed as a (2, 4) int32 array.
    """
    row_x, row_y = [x] * 4, [y] * 4
    return [[row_x, row_y]]

with tf.Session() as sess:
    dataset = tf.data.Dataset.from_generator(gen_range, (tf.int32, tf.int32)).map(
        lambda x, y: tf.py_func(bridge, [x, y], [tf.int32]), num_parallel_calls=2)
    # Build the get_next() op once. Calling it inside the loop would add a
    # new IteratorGetNext op to the graph on every iteration.
    next_element = dataset.make_one_shot_iterator().get_next()
    # No tf.Variable is created in this pipeline, so no initializer is run.
    while True:
        try:
            print(sess.run(next_element))
        except tf.errors.OutOfRangeError:
            # The generator is exhausted: this is the normal end of the data.
            break
输出
(array([[  0,   0,   0,   0],
   [250, 250, 250, 250]], dtype=int32),)
(array([[250, 250, 250, 250],
   [500, 500, 500, 500]], dtype=int32),)
(array([[500, 500, 500, 500],
   [750, 750, 750, 750]], dtype=int32),)
2018-01-09 01:56:12.871943: W tensorflow/core/framework/op_kernel.cc:1192] Out of range: StopIteration: Iteration finished.
(array([[ 750,  750,  750,  750],
   [1000, 1000, 1000, 1000]], dtype=int32),)

失败案例

def bridge(x, y):
    """Return a ragged nested list: ``x`` repeated five times, ``y`` four times.

    The two inner rows differ in length (5 vs. 4), so the value cannot be
    converted to a dense tensor. ``tf.py_func`` fails with
    ``UnimplementedError: Unsupported object type list``. The rows are kept
    ragged on purpose to reproduce that failure.
    """
    first_row = [x] * 5
    second_row = [y] * 4
    return [[first_row, second_row]]

with tf.Session() as sess:
    dataset = tf.data.Dataset.from_generator(gen_range, (tf.int32, tf.int32)).map(
        lambda x, y: tf.py_func(bridge, [x, y], [tf.int32]), num_parallel_calls=2)
    # Build the get_next() op once. Calling it inside the loop would add a
    # new IteratorGetNext op to the graph on every iteration.
    next_element = dataset.make_one_shot_iterator().get_next()
    # No tf.Variable is created in this pipeline, so no initializer is run.
    while True:
        try:
            print(sess.run(next_element))
        except tf.errors.OutOfRangeError:
            # Normal end of data. The UnimplementedError caused by the ragged
            # value from bridge is a different error type and still propagates.
            break
输出
2018-01-09 01:56:41.201683: W tensorflow/core/framework/op_kernel.cc:1192] Unimplemented: Unsupported object type list
2018-01-09 01:56:41.201960: W tensorflow/core/framework/op_kernel.cc:1192] Unimplemented: Unsupported object type list
 [[Node: PyFunc = PyFunc[Tin=[DT_INT32, DT_INT32], Tout=[DT_INT32], token="pyfunc_129"](arg0, arg1)]]
 2018-01-09 01:56:41.202002: W tensorflow/core/framework/op_kernel.cc:1192] Unimplemented: Unsupported object type list
---------------------------------------------------------------------------
UnimplementedError                        Traceback (most recent call last)
/media/user/code/tf_data_exploration/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
1322     try:
-> 1323       return fn(*args)
   1324     except errors.OpError as e:

 /media/user/code/tf_data_exploration/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
 1301                                    feed_dict, fetch_list, target_list,
 -> 1302                                    status, run_metadata)
 1303 

 /media/user/code/tf_data_exploration/venv/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py in __exit__(self, type_arg, value_arg, traceback_arg)
 472             compat.as_text(c_api.TF_Message(self.status.status)),
 --> 473             c_api.TF_GetCode(self.status.status))
     474     # Delete the underlying status object from memory otherwise it stays alive

 UnimplementedError: Unsupported object type list
 [[Node: PyFunc = PyFunc[Tin=[DT_INT32, DT_INT32], Tout=[DT_INT32], token="pyfunc_129"](arg0, arg1)]]
 [[Node: IteratorGetNext_95 = IteratorGetNext[output_shapes=
 [<unknown>], output_types=[DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator_35)]]

 During handling of the above exception, another exception occurred:

 UnimplementedError                        Traceback (most recent call last)
 <ipython-input-120-120ecc56d75d> in <module>()
       4     sess.run(init)
       5     while True:
 ----> 6         print(sess.run(dataset.get_next()))
       7 
       8 

 /media/user/code/tf_data_exploration/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
     887     try:
     888       result = self._run(None, fetches, feed_dict, options_ptr,
 --> 889                          run_metadata_ptr)
     890       if run_metadata:
     891         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

 /media/user/code/tf_data_exploration/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
    1118     if final_fetches or final_targets or (handle and feed_dict_tensor):
    1119       results = self._do_run(handle, final_targets, final_fetches,
 -> 1120                              feed_dict_tensor, options, run_metadata)
    1121     else:
    1122       results = []

 /media/user/code/tf_data_exploration/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
 1315     if handle is None:
 1316       return self._do_call(_run_fn, self._session, feeds, fetches, targets,
 -> 1317                            options, run_metadata)
    1318     else:
    1319       return self._do_call(_prun_fn, self._session, handle, feeds, fetches)

 /media/user/code/tf_data_exploration/venv/lib/python3.6/site-packages/tensorflow/python/client/session.py in _do_call(self, fn, *args)
    1334         except KeyError:
    1335           pass
 -> 1336       raise type(e)(node_def, op, message)
    1337 
    1338   def _extend_graph(self):

 UnimplementedError: Unsupported object type list
     [[Node: PyFunc = PyFunc[Tin=[DT_INT32, DT_INT32], Tout=[DT_INT32], token="pyfunc_129"](arg0, arg1)]]
     [[Node: IteratorGetNext_95 = IteratorGetNext[output_shapes=[<unknown>], output_types=[DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator_35)]]

问题不在于列表的大小,而在于您返回的值无法转换为 Tensor。

当你返回 `[[[x] * 4, [y] * 4]]` 时,它可以被转换为形状为 (1, 2, 4) 的 Tensor:
# x and y stand for the scalar values passed into bridge (for example x=1, y=2).
res = tf.constant([[[x] * 4, [y] * 4]])
print(res.get_shape())  # prints (1, 2, 4)

当你返回 `[[[x] * 5, [y] * 4]]` 时,以 x=1, y=2 为例,你会得到这样的结果:

[[[1, 1, 1, 1, 1],
  [2, 2, 2, 2]
]]

这无法转换为 Tensor,因为第一行有 5 个元素而第二行只有 4 个,各行长度不一致,构成了一个不规则(ragged)的嵌套列表。


如果您尝试执行下面的代码,也会触发类似的错误:

# The rows have different lengths (2 vs. 1), so no dense shape exists and this call fails.
res = tf.constant([[1, 2], [3]])

Argument must be a dense tensor: [[1, 2], [3]] - got shape [2], but wanted [2, 2].

也就是说,TensorFlow 无法为各行长度不一致的嵌套列表推断出统一的张量形状。要解决这个问题,请确保每一行的长度相同(例如用填充值补齐),或者把各行作为独立的输出分别返回,并在 `Tout` 中为每个输出声明对应的类型。