我如何在 TF 2.0 中计算这个梯度?

How do I compute this gradient in TF 2.0?

我提供了一个我想要解决的问题的最小示例。我定义了一个 class 并且其中有一些跨不同函数定义的变量。我想知道如何跨函数跟踪这些变量以获得梯度。我想我必须使用 tf.GradientTape 但我尝试了一些变体但没有成功。

class A():
     def __init__(self):
         self.alpha = tf.Variable(2.0)
     def f1(self):
         wt = self.alpha * 5.0
         return wt
     def f2(self):
         wt_f1 = f1()
         with tf.GradientTape() as tape:
            wt_f2 = wt_f1 * 10.0
            print(tape.gradient(wt_f2, self.alpha))

a = A()
print(a.f2())

最后一行returnsNone。显然 wt_f2alpha 的导数是 50.0。但是,我得到 None。任何想法?我尝试在 __init__ 函数中初始化持久梯度带,并使用它来观察 wtself.alpha 等变量,但这没有帮助。有什么想法吗?

更新 1:

wt_f1 调用置于 tape 下无效。

class A():
     def __init__(self):
         self.alpha = tf.Variable(2.0)
     def f1(self):
         wt = self.alpha * 5.0
         return wt
     def f2(self):
         with tf.GradientTape() as tape:
            wt_f1 = f1()
            wt_f2 = wt_f1 * 10.0
            print(tape.gradient(wt_f2, self.alpha))

这也是returnsNone

您正在打印 None。因为f2()returns什么都没有,所以得到None。 删除打印:

a = A()
a.f2()

此外,一些编辑可能对您编写的代码有益。

  1. 您在 f1() 函数之前错过了 self 并且这有效,因为您在其他地方定义了 f1 函数。无论如何添加 self.f1().
  2. print 语句移出 tape 范围。因为最好在录制结束的地方做渐变。
  3. 添加 tape.watch() 以确保它被磁带跟踪。
class A():
    def __init__(self):
        self.alpha = tf.Variable(2.0)
    def f1(self):
        wt = self.alpha * 5.0
        return wt
    def f2(self):
        with tf.GradientTape() as tape:
            tape.watch(self.alpha)
            wt_f1 = self.f1()
            wt_f2 = wt_f1 * 10.0
        print(tape.gradient(wt_f2, self.alpha))