在tf.data中使用tf.while_loop修改数组内容报错
Using tf.while_loop in tf.data to modify the content of the array causes an error
我想用 for loop 来操作 tf.data 中的一个数组。对于现在的 tf.while_loop 方法,我必须匹配输入参数和输出的结构,所以我提前创建了一个数组 new_data,然后用 tf.while_loop 顺序修改数组的内容,但是结果往往是错误的。我是不是做错了什么?
代码:
import tensorflow as tf
import numpy as np
def body(index, new_data):
    # Loop body for tf.while_loop.  NOTE(review): creating a tf.Variable
    # from a slice of `new_data` (a loop-local graph tensor) is exactly what
    # triggers the "Graph tensor ... leak out of the function building
    # context" TypeError shown below — variables cannot be initialized from
    # tensors that only exist inside the while_loop's graph.
    tf.Variable(lambda:new_data[index, 0]).assign(1) #An error occurred
    tf.Variable(lambda:new_data[index, 0]).assign(1)
    # Return (index + 1, new_data) so the structure matches loop_vars.
    return tf.add(index,1), new_data
def main(data):
    # tf.data map function: allocates a fresh (200, 2) float64 zero tensor
    # and runs a 200-iteration while_loop over it.  The loop result `r` is
    # discarded and the dataset element is returned unchanged — the loop
    # exists only to demonstrate the error.
    new_data = tf.zeros((200,2), dtype=tf.float64)
    index = tf.constant(0)
    condition = lambda index, new_data: tf.less(index, 200)
    r = tf.while_loop(condition, body, loop_vars=(index, new_data))
    return data
# Input pipeline: 100 scalar elements, each mapped through `main`,
# batched in tens, prefetched; iterating it triggers the mapped function
# and therefore the error.
trainDS = tf.data.Dataset.from_tensor_slices((np.arange(100)))
trainDS = (
    trainDS
    .map(main, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(10, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE))
for i in trainDS:
    i  # value unused; the iteration itself runs the pipeline
错误:
test2.py:13 main *
r = tf.while_loop(condition, body, loop_vars=(index, new_data))
/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/deprecation.py:605 new_func **
return func(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/control_flow_ops.py:2499 while_loop_v2
return_same_structure=True)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/control_flow_ops.py:2696 while_loop
back_prop=back_prop)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/while_v2.py:200 while_loop
add_control_dependencies=add_control_dependencies)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py:990 func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/while_v2.py:178 wrapped_body
outputs = body(*_pack_sequence_as(orig_loop_vars, args))
test2.py:5 body
tf.Variable(lambda:new_data[index, 0]).assign(1)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:262 __call__
return cls._variable_v2_call(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:256 _variable_v2_call
shape=shape)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:237 <lambda>
previous_getter = lambda **kws: default_variable_creator_v2(None, **kws)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variable_scope.py:2667 default_variable_creator_v2
shape=shape)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:264 __call__
return super(VariableMetaclass, cls).__call__(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/resource_variable_ops.py:1585 __init__
distribute_strategy=distribute_strategy)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/resource_variable_ops.py:1712 _init_from_args
initial_value = initial_value()
test2.py:5 <lambda>
tf.Variable(lambda:new_data[index, 0]).assign(1)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/dispatch.py:201 wrapper
return target(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/array_ops.py:1011 _slice_helper
end.append(s + 1)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_ops.py:1180 binary_op_wrapper
raise e
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_ops.py:1164 binary_op_wrapper
return func(x, y, name=name)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/dispatch.py:201 wrapper
return target(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_ops.py:1486 _add_dispatch
return gen_math_ops.add_v2(x, y, name=name)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_math_ops.py:477 add_v2
x, y, name=name, ctx=_ctx)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_math_ops.py:501 add_v2_eager_fallback
ctx=ctx, name=name)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py:75 quick_execute
raise e
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py:60 quick_execute
inputs, attrs, num_outputs)
TypeError: An op outside of the function building code is being passed
a "Graph" tensor. It is possible to have Graph tensors
leak out of the function building context by including a
tf.init_scope in your function building code.
For example, the following function will fail:
@tf.function
def has_init_scope():
my_constant = tf.constant(1.)
with tf.init_scope():
added = my_constant * 2
The graph tensor has name: while/Placeholder:0
环境:python3.6 tensorflow-gpu2.4.0 Ubuntu20.04
您可以尝试将 tf.tensor_scatter_nd_update 用于您的用例,但请注意:目前 main 方法会对每个数据点(或每批数据点)调用一次,因此 new_data 每次都会被重新初始化。
import tensorflow as tf
import numpy as np
def body(index, new_data):
    """One tf.while_loop step: write 1.0 into both columns of row `index`.

    Returns (index + 1, updated tensor) so the loop-variable structure
    matches `loop_vars` — tf.tensor_scatter_nd_update is the functional
    replacement for in-place assignment, which graph tensors don't support.
    """
    # Fix: the original passed [[index, 0], [index, 0]] — the same element
    # twice.  The question's two assignments clearly target both columns of
    # the row, so the second update goes to column 1.
    new_data = tf.tensor_scatter_nd_update(
        new_data, [[index, 0], [index, 1]], [1.0, 1.0])
    return tf.add(index, 1), new_data
def main(data):
    # tf.data map function: note that this runs once per dataset element,
    # so `new_data` is re-initialized to zeros on every call.  The loop
    # result `r` is discarded; the element passes through unchanged.
    new_data = tf.zeros((200,2), dtype=tf.float64)
    index = tf.constant(0)
    condition = lambda index, new_data: tf.less(index, 200)
    r = tf.while_loop(condition, body, loop_vars=(index, new_data))
    return data
# Same pipeline as the question: 100 scalar elements mapped through
# `main`, batched in tens, prefetched; printing each batch forces the
# mapped function to execute.
trainDS = tf.data.Dataset.from_tensor_slices((np.arange(100)))
trainDS = (
    trainDS
    .map(main, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(10, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE))
for i in trainDS:
    print(i)
我想用 for loop 来操作 tf.data 中的一个数组。对于现在的 tf.while_loop 方法,我必须匹配输入参数和输出的结构,所以我提前创建了一个数组 new_data,然后用 tf.while_loop 顺序修改数组的内容,但是结果往往是错误的。我是不是做错了什么?
代码:
import tensorflow as tf
import numpy as np
def body(index, new_data):
    # Loop body for tf.while_loop.  NOTE(review): creating a tf.Variable
    # from a slice of `new_data` (a loop-local graph tensor) is exactly what
    # triggers the "Graph tensor ... leak out of the function building
    # context" TypeError shown below — variables cannot be initialized from
    # tensors that only exist inside the while_loop's graph.
    tf.Variable(lambda:new_data[index, 0]).assign(1) #An error occurred
    tf.Variable(lambda:new_data[index, 0]).assign(1)
    # Return (index + 1, new_data) so the structure matches loop_vars.
    return tf.add(index,1), new_data
def main(data):
    # tf.data map function: allocates a fresh (200, 2) float64 zero tensor
    # and runs a 200-iteration while_loop over it.  The loop result `r` is
    # discarded and the dataset element is returned unchanged — the loop
    # exists only to demonstrate the error.
    new_data = tf.zeros((200,2), dtype=tf.float64)
    index = tf.constant(0)
    condition = lambda index, new_data: tf.less(index, 200)
    r = tf.while_loop(condition, body, loop_vars=(index, new_data))
    return data
# Input pipeline: 100 scalar elements, each mapped through `main`,
# batched in tens, prefetched; iterating it triggers the mapped function
# and therefore the error.
trainDS = tf.data.Dataset.from_tensor_slices((np.arange(100)))
trainDS = (
    trainDS
    .map(main, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(10, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE))
for i in trainDS:
    i  # value unused; the iteration itself runs the pipeline
错误:
test2.py:13 main *
r = tf.while_loop(condition, body, loop_vars=(index, new_data))
/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/deprecation.py:605 new_func **
return func(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/control_flow_ops.py:2499 while_loop_v2
return_same_structure=True)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/control_flow_ops.py:2696 while_loop
back_prop=back_prop)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/while_v2.py:200 while_loop
add_control_dependencies=add_control_dependencies)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py:990 func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/while_v2.py:178 wrapped_body
outputs = body(*_pack_sequence_as(orig_loop_vars, args))
test2.py:5 body
tf.Variable(lambda:new_data[index, 0]).assign(1)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:262 __call__
return cls._variable_v2_call(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:256 _variable_v2_call
shape=shape)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:237 <lambda>
previous_getter = lambda **kws: default_variable_creator_v2(None, **kws)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variable_scope.py:2667 default_variable_creator_v2
shape=shape)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/variables.py:264 __call__
return super(VariableMetaclass, cls).__call__(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/resource_variable_ops.py:1585 __init__
distribute_strategy=distribute_strategy)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/resource_variable_ops.py:1712 _init_from_args
initial_value = initial_value()
test2.py:5 <lambda>
tf.Variable(lambda:new_data[index, 0]).assign(1)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/dispatch.py:201 wrapper
return target(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/array_ops.py:1011 _slice_helper
end.append(s + 1)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_ops.py:1180 binary_op_wrapper
raise e
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_ops.py:1164 binary_op_wrapper
return func(x, y, name=name)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/util/dispatch.py:201 wrapper
return target(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_ops.py:1486 _add_dispatch
return gen_math_ops.add_v2(x, y, name=name)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_math_ops.py:477 add_v2
x, y, name=name, ctx=_ctx)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_math_ops.py:501 add_v2_eager_fallback
ctx=ctx, name=name)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py:75 quick_execute
raise e
/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py:60 quick_execute
inputs, attrs, num_outputs)
TypeError: An op outside of the function building code is being passed
a "Graph" tensor. It is possible to have Graph tensors
leak out of the function building context by including a
tf.init_scope in your function building code.
For example, the following function will fail:
@tf.function
def has_init_scope():
my_constant = tf.constant(1.)
with tf.init_scope():
added = my_constant * 2
The graph tensor has name: while/Placeholder:0
环境:python3.6 tensorflow-gpu2.4.0 Ubuntu20.04
您可以尝试将 tf.tensor_scatter_nd_update 用于您的用例,但请注意:目前 main 方法会对每个数据点(或每批数据点)调用一次,因此 new_data 每次都会被重新初始化。
import tensorflow as tf
import numpy as np
def body(index, new_data):
    """One tf.while_loop step: write 1.0 into both columns of row `index`.

    Returns (index + 1, updated tensor) so the loop-variable structure
    matches `loop_vars` — tf.tensor_scatter_nd_update is the functional
    replacement for in-place assignment, which graph tensors don't support.
    """
    # Fix: the original passed [[index, 0], [index, 0]] — the same element
    # twice.  The question's two assignments clearly target both columns of
    # the row, so the second update goes to column 1.
    new_data = tf.tensor_scatter_nd_update(
        new_data, [[index, 0], [index, 1]], [1.0, 1.0])
    return tf.add(index, 1), new_data
def main(data):
    # tf.data map function: note that this runs once per dataset element,
    # so `new_data` is re-initialized to zeros on every call.  The loop
    # result `r` is discarded; the element passes through unchanged.
    new_data = tf.zeros((200,2), dtype=tf.float64)
    index = tf.constant(0)
    condition = lambda index, new_data: tf.less(index, 200)
    r = tf.while_loop(condition, body, loop_vars=(index, new_data))
    return data
# Same pipeline as the question: 100 scalar elements mapped through
# `main`, batched in tens, prefetched; printing each batch forces the
# mapped function to execute.
trainDS = tf.data.Dataset.from_tensor_slices((np.arange(100)))
trainDS = (
    trainDS
    .map(main, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(10, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE))
for i in trainDS:
    print(i)