通过名称分配张量的值
Assigning the value of a tensor by its name
我正在使用以下代码在 TensorFlow 图中获取张量:
names = [var.name for var in self.graph.get_collection('trainable_variables')]
tensor = graph.get_tensor_by_name(names[0])
sess.run(tensor)
如何设置 tensor
的值?
大多数 TensorFlow 张量(tf.Tensor
个对象)是 不可变的 ,因此您不能简单地为它们赋值。但是,如果您将张量创建为 tf.Variable
, you can assign a value to it by calling Variable.assign()
。
您的代码不必要地将 tf.Variable
对象(来自 tf.trainable_variables()
的列表)转换为字符串名称。相反,您可以执行以下操作:
# Get the 0th trainable variable.
var = tf.trainable_variables()[0]
# Create an op to assign a new value.
assign_op = var.assign([[1.0, 2.0], [3.0, 4.0]])
# Actually run the assignment.
sess.run(assign_op)
但是,根据你的评论,你有多个图(即self.graph
和graph
是不同的),所以我上面写的一般解决方案不会工作。在这种情况下,您有两个选择:
在另一个图中按名称获取变量(N.B。 这仅在 graph_2.get_collection('trainable_variables')
已填充时有效;如果您使用 tf.import_graph_def()
来构建图形,它将不起作用):
var_name = graph_1.get_collection('trainable_variables')[0].name
var_in_g2 = [v for v in graph_2.get_collection('trainable_variables')
if v.name == var_name][0]
assign_op = var_in_g2.assign([[1.0, 2.0], [3.0, 4.0]])
sess.run(assign_op)
在另一个图中按名称获取张量,并使用tf.assign()
:
var_name = graph_1.get_collection('trainable_variables')[0].name
var_in_g2 = graph_2.get_tensor_by_name(var_name)
assign_op = tf.assign(var_in_g2, [[1.0, 2.0], [3.0, 4.0]])
sess.run(assign_op)
我正在使用以下代码在 TensorFlow 图中获取张量:
names = [var.name for var in self.graph.get_collection('trainable_variables')]
tensor = graph.get_tensor_by_name(names[0])
sess.run(tensor)
如何设置 tensor
的值?
大多数 TensorFlow 张量(tf.Tensor
个对象)是 不可变的 ,因此您不能简单地为它们赋值。但是,如果您将张量创建为 tf.Variable
, you can assign a value to it by calling Variable.assign()
。
您的代码不必要地将 tf.Variable
对象(来自 tf.trainable_variables()
的列表)转换为字符串名称。相反,您可以执行以下操作:
# Get the 0th trainable variable.
var = tf.trainable_variables()[0]
# Create an op to assign a new value.
assign_op = var.assign([[1.0, 2.0], [3.0, 4.0]])
# Actually run the assignment.
sess.run(assign_op)
但是,根据你的评论,你有多个图(即self.graph
和graph
是不同的),所以我上面写的一般解决方案不会工作。在这种情况下,您有两个选择:
在另一个图中按名称获取变量(N.B。 这仅在
graph_2.get_collection('trainable_variables')
已填充时有效;如果您使用tf.import_graph_def()
来构建图形,它将不起作用):var_name = graph_1.get_collection('trainable_variables')[0].name var_in_g2 = [v for v in graph_2.get_collection('trainable_variables') if v.name == var_name][0] assign_op = var_in_g2.assign([[1.0, 2.0], [3.0, 4.0]]) sess.run(assign_op)
在另一个图中按名称获取张量,并使用
tf.assign()
:var_name = graph_1.get_collection('trainable_variables')[0].name var_in_g2 = graph_2.get_tensor_by_name(var_name) assign_op = tf.assign(var_in_g2, [[1.0, 2.0], [3.0, 4.0]]) sess.run(assign_op)