没有急切执行的张量流中的经常性损失
Recurrent loss in tensorflow without executing eagerly
我有以下非常简单的损失示例(可能没有意义)
import tensorflow as tf
class Loss:
def __init__(self):
self.last_output = tf.constant([0.5,0.5])
def recurrent_loss(self, model_output):
now = 0.9*self.last_output + 0.1*model_output
self.last_output = now
return tf.reduce_mean(now)
仅计算 model_output 的 reduced_mean
与 最后一个 model_output 的组合(比率为 9 比 1)。例如
>> l = Loss()
>> l.recurrent_loss(tf.constant([1.,1.]))
tf.Tensor(0.55, shape=(), dtype=float32)
>> l.recurrent_loss(tf.constant([1.,1.]))
tf.Tensor(0.595, shape=(), dtype=float32)
如果我正确理解 tf 是如何工作的,那么这是唯一可能的,因为默认情况下 tf 正在急切地执行 (tf.executing_eagerly() == True
)。这应该是我可以用新张量覆盖 self.last_output 变量以实现循环结构的原因。
我的问题:如何在不使用急切执行的 tf 图中实现相同类型的循环结构?
在图形模式下,您必须使用仅在第一次执行函数时创建的 tf.Variable,例如:
class Loss:
def __init__(self):
self.last_output = None
@tf.function
def recurrent_loss(self, model_output):
if self.last_output is None:
self.last_output = tf.Variable([0.5,0.5])
now = 0.9*self.last_output + 0.1*model_output
self.last_output.assign(now)
return tf.reduce_mean(now)
我有以下非常简单的损失示例(可能没有意义)
import tensorflow as tf
class Loss:
def __init__(self):
self.last_output = tf.constant([0.5,0.5])
def recurrent_loss(self, model_output):
now = 0.9*self.last_output + 0.1*model_output
self.last_output = now
return tf.reduce_mean(now)
仅计算 model_output 的 reduced_mean
与 最后一个 model_output 的组合(比率为 9 比 1)。例如
>> l = Loss()
>> l.recurrent_loss(tf.constant([1.,1.]))
tf.Tensor(0.55, shape=(), dtype=float32)
>> l.recurrent_loss(tf.constant([1.,1.]))
tf.Tensor(0.595, shape=(), dtype=float32)
如果我正确理解 tf 是如何工作的,那么这是唯一可能的,因为默认情况下 tf 正在急切地执行 (tf.executing_eagerly() == True
)。这应该是我可以用新张量覆盖 self.last_output 变量以实现循环结构的原因。
我的问题:如何在不使用急切执行的 tf 图中实现相同类型的循环结构?
在图形模式下,您必须使用仅在第一次执行函数时创建的 tf.Variable,例如:
class Loss:
def __init__(self):
self.last_output = None
@tf.function
def recurrent_loss(self, model_output):
if self.last_output is None:
self.last_output = tf.Variable([0.5,0.5])
now = 0.9*self.last_output + 0.1*model_output
self.last_output.assign(now)
return tf.reduce_mean(now)