在tf.data中使用tf.while_loop修改数组内容报错

Using tf.while_loop in tf.data to modify the content of the array causes an error

我想用for loop来操作tf.data中的一个数组。对于现在的tf.while_loop方法,我必须要匹配输入参数和输出,所以我提前创建了一个数组new_data,然后用tf.while_loop顺序修改数组的内容,但是结果往往是错误的。我是不是做错了什么?

代码:

import tensorflow as tf
import numpy as np

def body(index, new_data):
    """Loop body handed to ``tf.while_loop``; tries to set ``new_data[index, 0]`` to 1.

    Returns ``index + 1`` together with ``new_data``. The ``tf.Variable``
    objects built below are never stored or written back, so the
    ``new_data`` that is returned is the same tensor that was passed in.
    """
    # This is the statement the traceback below points at: the initial-value
    # lambda is called while the variable is being constructed and slices
    # new_data with the loop's symbolic index (the "while/Placeholder:0" graph
    # tensor named in the TypeError), which TensorFlow rejects as a graph
    # tensor used outside the function-building code.
    tf.Variable(lambda:new_data[index, 0]).assign(1)  #An error occurred
    # Exact duplicate of the statement above (same row, same column 0).
    tf.Variable(lambda:new_data[index, 0]).assign(1)
    return tf.add(index,1), new_data
    
def main(data):
    """Dataset map function: run the while loop on a fresh zero array, return ``data`` as-is.

    ``new_data`` is created inside this function, so every call starts from a
    new (200, 2) float64 array of zeros. The loop result is bound to ``r`` but
    never used; the input ``data`` is returned unchanged.
    """
    new_data = tf.zeros((200,2), dtype=tf.float64)
    index = tf.constant(0)
    # Keep looping while index < 200 (matches the 200 rows of new_data).
    condition = lambda index, new_data: tf.less(index, 200)
    # This call is where the traceback below enters `body`.
    r = tf.while_loop(condition, body, loop_vars=(index, new_data))
    return data
    
# Source dataset: the integers 0..99, one per element.
trainDS = tf.data.Dataset.from_tensor_slices(np.arange(100))
# Apply `main` to every element, then group into batches of 10
# (dropping any incomplete final batch) and prefetch.
trainDS = trainDS.map(main, num_parallel_calls=tf.data.AUTOTUNE)
trainDS = trainDS.batch(10, drop_remainder=True)
trainDS = trainDS.prefetch(tf.data.AUTOTUNE)

# Iterate over the batches; the bare expression does nothing with each one.
for i in trainDS:
    i

错误:

    test2.py:13 main  *
        r = tf.while_loop(condition, body, loop_vars=(index, new_data))
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/util/deprecation.py:605 new_func  **
        return func(*args, **kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/control_flow_ops.py:2499 while_loop_v2
        return_same_structure=True)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/control_flow_ops.py:2696 while_loop
        back_prop=back_prop)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/while_v2.py:200 while_loop
        add_control_dependencies=add_control_dependencies)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py:990 func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/while_v2.py:178 wrapped_body
        outputs = body(*_pack_sequence_as(orig_loop_vars, args))
    test2.py:5 body
        tf.Variable(lambda:new_data[index, 0]).assign(1)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:262 __call__
        return cls._variable_v2_call(*args, **kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:256 _variable_v2_call
        shape=shape)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:237 <lambda>
        previous_getter = lambda **kws: default_variable_creator_v2(None, **kws)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variable_scope.py:2667 default_variable_creator_v2
        shape=shape)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:264 __call__
        return super(VariableMetaclass, cls).__call__(*args, **kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/resource_variable_ops.py:1585 __init__
        distribute_strategy=distribute_strategy)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/resource_variable_ops.py:1712 _init_from_args
        initial_value = initial_value()
    test2.py:5 <lambda>
        tf.Variable(lambda:new_data[index, 0]).assign(1)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/util/dispatch.py:201 wrapper
        return target(*args, **kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/array_ops.py:1011 _slice_helper
        end.append(s + 1)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_ops.py:1180 binary_op_wrapper
        raise e
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_ops.py:1164 binary_op_wrapper
        return func(x, y, name=name)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/util/dispatch.py:201 wrapper
        return target(*args, **kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_ops.py:1486 _add_dispatch
        return gen_math_ops.add_v2(x, y, name=name)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_math_ops.py:477 add_v2
        x, y, name=name, ctx=_ctx)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_math_ops.py:501 add_v2_eager_fallback
        ctx=ctx, name=name)
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py:75 quick_execute
        raise e
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py:60 quick_execute
        inputs, attrs, num_outputs)

    TypeError: An op outside of the function building code is being passed
    a "Graph" tensor. It is possible to have Graph tensors
    leak out of the function building context by including a
    tf.init_scope in your function building code.
    For example, the following function will fail:
      @tf.function
      def has_init_scope():
        my_constant = tf.constant(1.)
        with tf.init_scope():
          added = my_constant * 2
    The graph tensor has name: while/Placeholder:0

环境:Python 3.6、tensorflow-gpu 2.4.0、Ubuntu 20.04

您可以尝试在您的用例中使用 tf.tensor_scatter_nd_update。但请注意,目前 main 方法会针对每个数据点(或每批数据点)被调用一次,因此 new_data 每次都会被重新初始化。

import tensorflow as tf
import numpy as np

def body(index, new_data):
    """Loop body for ``tf.while_loop``: write 1.0 into ``new_data[index, 0]``.

    Tensors are not modified in place, so the update is expressed with
    ``tf.tensor_scatter_nd_update``, which returns a new tensor; that new
    tensor is handed on to the next iteration.

    Returns:
        A tuple ``(index + 1, updated new_data)`` matching ``loop_vars``.
    """
    # Both scatter targets are the same cell (row `index`, column 0) and both
    # carry the same value, mirroring the two identical assignments in the
    # question's code.
    targets = [[index, 0], [index, 0]]
    values = [1.0, 1.0]
    updated = tf.tensor_scatter_nd_update(new_data, targets, values)
    next_index = tf.add(index, 1)
    return next_index, updated
    
def main(data):
    """Dataset map function: run the scatter-update loop, then return ``data`` as-is.

    A new (200, 2) float64 zero array is created on every call, so nothing
    carries over between dataset elements. The loop result is not used; the
    input ``data`` is returned unchanged.
    """
    zeros = tf.zeros((200, 2), dtype=tf.float64)
    start = tf.constant(0)

    def condition(index, new_data):
        # Continue while index < 200 (one iteration per row of the array).
        return tf.less(index, 200)

    loop_result = tf.while_loop(condition, body, loop_vars=(start, zeros))
    return data
    
# Pipeline: integers 0..99 -> `main` per element -> batches of 10
# (incomplete final batch dropped) -> prefetch.
trainDS = (
    tf.data.Dataset.from_tensor_slices(np.arange(100))
    .map(main, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(10, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)

# Print every batch produced by the pipeline.
for i in trainDS:
    print(i)