通过名称分配张量的值

Assigning the value of a tensor by its name

我正在使用以下代码在 TensorFlow 图中获取张量:

names = [var.name for var in self.graph.get_collection('trainable_variables')]
tensor = graph.get_tensor_by_name(names[0])
sess.run(tensor)

如何设置 tensor 的值?

大多数 TensorFlow 张量(tf.Tensor 个对象)是 不可变的 ,因此您不能简单地为它们赋值。但是,如果您将张量创建为 tf.Variable, you can assign a value to it by calling Variable.assign()

您的代码不必要地将 tf.Variable 对象(来自 tf.trainable_variables() 的列表)转换为字符串名称。相反,您可以执行以下操作:

# Get the 0th trainable variable.
var = tf.trainable_variables()[0]

# Create an op to assign a new value.
assign_op = var.assign([[1.0, 2.0], [3.0, 4.0]])

# Actually run the assignment.
sess.run(assign_op)

但是,根据你的评论,你有多个图(即self.graphgraph是不同的),所以我上面写的一般解决方案不会工作。在这种情况下,您有两个选择:

  1. 在另一个图中按名称获取变量(N.B。 这仅在 graph_2.get_collection('trainable_variables') 已填充时有效;如果您使用 tf.import_graph_def() 来构建图形,它将不起作用):

    var_name = graph_1.get_collection('trainable_variables')[0].name
    
    var_in_g2 = [v for v in graph_2.get_collection('trainable_variables')
                 if v.name == var_name][0]
    
    assign_op = var_in_g2.assign([[1.0, 2.0], [3.0, 4.0]])
    sess.run(assign_op)
    
  2. 在另一个图中按名称获取张量,并使用tf.assign():

    var_name = graph_1.get_collection('trainable_variables')[0].name
    
    var_in_g2 = graph_2.get_tensor_by_name(var_name)
    
    assign_op = tf.assign(var_in_g2, [[1.0, 2.0], [3.0, 4.0]])
    sess.run(assign_op)