使用 tensorflow 2.1.0 更新字典
Updating dictionary with tensorflow 2.1.0
我遇到了一个问题,我似乎无法在其他任何地方找到解决方案,所以我决定 post 我的问题在这里(我有 tensorflow 的基本知识,但很新):
我在python中写了一个简单的代码来说明我想做什么。
import tensorflow as tf
def generated_dict():
graph = {'input': tf.Variable(2)}
graph['layer_1'] = tf.square(graph['input'])
graph['layer_2'] = tf.add(graph['input'], graph['layer_1'])
return graph
graph = generated_dict()
print("boo = " + str(graph['layer_2']))
graph['input'].assign(tf.constant(3))
print("far = " + str(graph['layer_2']))
在这个示例代码中,当我通过 graph['input'].assign(tf.constant(3))
分配新的输入值时,我希望 tensorflow 更新整个字典。基本上,现在我得到
boo = tf.Tensor(6, shape=(), dtype=int32) # 2²+2
far = tf.Tensor(6, shape=(), dtype=int32) # 2²+2
这是正常的,因为我的代码急于执行。但是我希望字典用我的新输入更新它的值并得到:
boo = tf.Tensor(6, shape=(), dtype=int32) #2²+2
far = tf.Tensor(12, shape=(), dtype=int32) #3²+3
我觉得我应该使用 tf.function() 但我不确定我应该如何处理它。我试过 graph = tf.function(generated_graph)()
但我没有帮助。
任何帮助将不胜感激。
首先,在尝试将 TF1 代码改编为 TF2 时,我建议阅读指南:
Migrate your TensorFlow 1 code to TensorFlow 2.
TF2 将 tensorflow 的基本设计从图中的 运行 ops 更改为 eager execution。这意味着在 TF2 中编写代码与在普通 python 中编写代码非常接近:所有抽象、图形创建等都在幕后完成。
你复杂的设计不需要存在于TF2中,写一个简单的python功能即可。您甚至可以使用普通运算符代替张量流函数。那里会转换成tensorflow算子。或者,您可以使用 tf.function
decorator for performances.
@tf.function
def my_graph(x):
return x**2 + x
现在,如果你想在该函数中提供数据,你只需要用一个值调用它:
>>> my_graph(tf.constant(3))
<tf.Tensor: shape=(), dtype=int32, numpy=12>
>>> my_graph(tf.constant(2))
<tf.Tensor: shape=(), dtype=int32, numpy=6>
我遇到了一个问题,我似乎无法在其他任何地方找到解决方案,所以我决定 post 我的问题在这里(我有 tensorflow 的基本知识,但很新):
我在python中写了一个简单的代码来说明我想做什么。
import tensorflow as tf
def generated_dict():
graph = {'input': tf.Variable(2)}
graph['layer_1'] = tf.square(graph['input'])
graph['layer_2'] = tf.add(graph['input'], graph['layer_1'])
return graph
graph = generated_dict()
print("boo = " + str(graph['layer_2']))
graph['input'].assign(tf.constant(3))
print("far = " + str(graph['layer_2']))
在这个示例代码中,当我通过 graph['input'].assign(tf.constant(3))
分配新的输入值时,我希望 tensorflow 更新整个字典。基本上,现在我得到
boo = tf.Tensor(6, shape=(), dtype=int32) # 2²+2
far = tf.Tensor(6, shape=(), dtype=int32) # 2²+2
这是正常的,因为我的代码急于执行。但是我希望字典用我的新输入更新它的值并得到:
boo = tf.Tensor(6, shape=(), dtype=int32) #2²+2
far = tf.Tensor(12, shape=(), dtype=int32) #3²+3
我觉得我应该使用 tf.function() 但我不确定我应该如何处理它。我试过 graph = tf.function(generated_graph)()
但我没有帮助。
任何帮助将不胜感激。
首先,在尝试将 TF1 代码改编为 TF2 时,我建议阅读指南: Migrate your TensorFlow 1 code to TensorFlow 2.
TF2 将 tensorflow 的基本设计从图中的 运行 ops 更改为 eager execution。这意味着在 TF2 中编写代码与在普通 python 中编写代码非常接近:所有抽象、图形创建等都在幕后完成。
你复杂的设计不需要存在于TF2中,写一个简单的python功能即可。您甚至可以使用普通运算符代替张量流函数。那里会转换成tensorflow算子。或者,您可以使用 tf.function
decorator for performances.
@tf.function
def my_graph(x):
return x**2 + x
现在,如果你想在该函数中提供数据,你只需要用一个值调用它:
>>> my_graph(tf.constant(3))
<tf.Tensor: shape=(), dtype=int32, numpy=12>
>>> my_graph(tf.constant(2))
<tf.Tensor: shape=(), dtype=int32, numpy=6>