Tensorflow function doesn't change attribute's attribute

A tf.function doesn't change my object's attribute:

import tensorflow as tf

class f:
    v = 7
    def __call__(self):
        self.v = self.v + 1

@tf.function
def call(c):
    tf.print(c.v)  # always 7
    c()
    tf.print(c.v)  # always 8

c = f()
call(c)
call(c)

Expected output: 7 8 8 9

Actual output: 7 8 7 8

When I remove the @tf.function decorator, everything works as expected. How can I make my function work as expected while still using @tf.function?

This behavior is documented here:

Side effects, like printing, appending to lists, and mutating globals, can behave unexpectedly inside a Function, sometimes executing twice or not at all. They only happen the first time you call a Function with a set of inputs. Afterwards, the traced tf.Graph is reexecuted, without executing the Python code. The general rule of thumb is to avoid relying on Python side effects in your logic and only use them to debug your traces. Otherwise, TensorFlow APIs like tf.data, tf.print, tf.summary, tf.Variable.assign, and tf.TensorArray are the best way to ensure your code will be executed by the TensorFlow runtime with each call.
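To see the rule from the quote in isolation, here is a minimal sketch (the names trace_log, counter and step are hypothetical, chosen only for illustration). A Python list append runs only while the function is being traced, whereas assign_add on a tf.Variable is a graph operation that runs on every call:

```python
import tensorflow as tf

trace_log = []            # hypothetical Python-side log: only appended during tracing
counter = tf.Variable(0)  # hypothetical TensorFlow state: updated on every call

@tf.function
def step():
    trace_log.append("traced")  # Python side effect: executes only while tracing
    counter.assign_add(1)       # graph op: executes each time the graph runs

for _ in range(3):
    step()  # same (empty) set of inputs, so the first trace is reused

print(len(trace_log))        # 1
print(int(counter.numpy()))  # 3
```

The list grows once because the Python body is executed only to build the graph; the variable reaches 3 because the graph itself runs three times.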

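In your code, c.v is a plain Python int. During the first call, when the function is traced, the two tf.print calls capture the values 7 and 8 as constants in the graph. The second call replays that graph without running c() again, so it prints 7 8 again. So, try using a tf.Variable instead; reading it and calling assign_add on it are graph operations that run on every call: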

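import tensorflow as tf

class f:
    def __init__(self):
        # Create the variable per instance; a class-level tf.Variable
        # would be shared by every instance of f
        self.v = tf.Variable(7)

    def __call__(self):
        # assign_add is a TensorFlow op, so it runs on every graph execution
        self.v.assign_add(1)

@tf.function
def call(c):
    tf.print(c.v)  # current value: 7 on the first call, 8 on the second
    c()
    tf.print(c.v)  # incremented value: 8 on the first call, 9 on the second

c = f()
call(c)
call(c)

Output:

7
8
8
9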
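Another option, in the spirit of the quoted advice to avoid relying on Python side effects, is to keep the state outside the function entirely: pass it in as a tensor and return the updated value. This is a minimal sketch that replaces the class from the question rather than reproducing it; increment and state are hypothetical names.

```python
import tensorflow as tf

@tf.function
def increment(v):
    # v arrives as a tensor, so tf.print reads its runtime value on every call
    tf.print(v)
    v = v + 1
    tf.print(v)
    return v  # hand the new state back instead of mutating a Python object

state = tf.constant(7)
state = increment(state)  # prints 7, then 8
state = increment(state)  # prints 8, then 9
print(int(state.numpy()))  # 9
```

Because state is a scalar int32 tensor on both calls, the function is traced only once. Passing plain Python ints instead (for example increment(7)) would trigger a new trace for every distinct value.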