没有急切执行的张量流中的经常性损失

Recurrent loss in tensorflow without executing eagerly

我有以下非常简单的损失示例(可能没有意义)

import tensorflow as tf

class Loss:
  def __init__(self):
    self.last_output = tf.constant([0.5,0.5])

  def recurrent_loss(self, model_output):
    now = 0.9*self.last_output + 0.1*model_output
    self.last_output = now
    return tf.reduce_mean(now)

仅计算 model_output 的 reduced_mean 最后一个 model_output 的组合(比率为 9 比 1)。例如

>> l = Loss()
>> l.recurrent_loss(tf.constant([1.,1.]))
tf.Tensor(0.55, shape=(), dtype=float32)
>> l.recurrent_loss(tf.constant([1.,1.]))
tf.Tensor(0.595, shape=(), dtype=float32)

如果我正确理解 tf 是如何工作的,那么这是唯一可能的,因为默认情况下 tf 正在急切地执行 (tf.executing_eagerly() == True)。这应该是我可以用新张量覆盖 self.last_output 变量以实现循环结构的原因。

我的问题:如何在不使用急切执行的 tf 图中实现相同类型的循环结构?

在图形模式下,您必须使用仅在第一次执行函数时创建的 tf.Variable,例如:

class Loss:
  def __init__(self):
    self.last_output = None
  @tf.function
  def recurrent_loss(self, model_output):
    if  self.last_output is None:
        
        self.last_output = tf.Variable([0.5,0.5])
        
    now = 0.9*self.last_output + 0.1*model_output
    self.last_output.assign(now)
    return tf.reduce_mean(now)