使用 tensorflow 2.1.0 更新字典

Updating dictionary with tensorflow 2.1.0

我遇到了一个问题,我似乎无法在其他任何地方找到解决方案,所以我决定 post 我的问题在这里(我有 tensorflow 的基本知识,但很新):

我在python中写了一个简单的代码来说明我想做什么。

    import tensorflow as tf

    def generated_dict():
       graph = {'input': tf.Variable(2)}
       graph['layer_1'] = tf.square(graph['input'])
       graph['layer_2'] = tf.add(graph['input'], graph['layer_1'])
       return graph

    graph = generated_dict()
    print("boo = " + str(graph['layer_2']))

    graph['input'].assign(tf.constant(3))
    print("far = " + str(graph['layer_2']))

在这个示例代码中,当我通过 graph['input'].assign(tf.constant(3)) 分配新的输入值时,我希望 tensorflow 更新整个字典。基本上,现在我得到

boo = tf.Tensor(6, shape=(), dtype=int32) # 2²+2
far = tf.Tensor(6, shape=(), dtype=int32) # 2²+2

这是正常的,因为我的代码急于执行。但是我希望字典用我的新输入更新它的值并得到:

boo = tf.Tensor(6, shape=(), dtype=int32) #2²+2
far = tf.Tensor(12, shape=(), dtype=int32) #3²+3

我觉得我应该使用 tf.function() 但我不确定我应该如何处理它。我试过 graph = tf.function(generated_graph)() 但我没有帮助。 任何帮助将不胜感激。

首先,在尝试将 TF1 代码改编为 TF2 时,我建议阅读指南: Migrate your TensorFlow 1 code to TensorFlow 2.

TF2 将 tensorflow 的基本设计从图中的 运行 ops 更改为 eager execution。这意味着在 TF2 中编写代码与在普通 python 中编写代码非常接近:所有抽象、图形创建等都在幕后完成。

你复杂的设计不需要存在于TF2中,写一个简单的python功能即可。您甚至可以使用普通运算符代替张量流函数。那里会转换成tensorflow算子。或者,您可以使用 tf.function decorator for performances.

@tf.function
def my_graph(x):
     return x**2 + x

现在,如果你想在该函数中提供数据,你只需要用一个值调用它:

>>> my_graph(tf.constant(3))
<tf.Tensor: shape=(), dtype=int32, numpy=12>
>>> my_graph(tf.constant(2))
<tf.Tensor: shape=(), dtype=int32, numpy=6>