我如何在 TF 2.0 中计算这个梯度?
How do I compute this gradient in TF 2.0?
我提供了一个我想要解决的问题的最小示例。我定义了一个 class 并且其中有一些跨不同函数定义的变量。我想知道如何跨函数跟踪这些变量以获得梯度。我想我必须使用 tf.GradientTape
但我尝试了一些变体但没有成功。
class A():
def __init__(self):
self.alpha = tf.Variable(2.0)
def f1(self):
wt = self.alpha * 5.0
return wt
def f2(self):
wt_f1 = f1()
with tf.GradientTape() as tape:
wt_f2 = wt_f1 * 10.0
print(tape.gradient(wt_f2, self.alpha))
a = A()
print(a.f2())
最后一行returnsNone
。显然 wt_f2
对 alpha
的导数是 50.0。但是,我得到 None
。任何想法?我尝试在 __init__
函数中初始化持久梯度带,并使用它来观察 wt
和 self.alpha
等变量,但这没有帮助。有什么想法吗?
更新 1:
将 wt_f1
调用置于 tape
下无效。
class A():
def __init__(self):
self.alpha = tf.Variable(2.0)
def f1(self):
wt = self.alpha * 5.0
return wt
def f2(self):
with tf.GradientTape() as tape:
wt_f1 = f1()
wt_f2 = wt_f1 * 10.0
print(tape.gradient(wt_f2, self.alpha))
这也是returnsNone
。
您正在打印 None。因为f2()
returns什么都没有,所以得到None
。
删除打印:
a = A()
a.f2()
此外,一些编辑可能对您编写的代码有益。
- 您在
f1()
函数之前错过了 self
并且这有效,因为您在其他地方定义了 f1
函数。无论如何添加 self.f1()
.
- 将
print
语句移出 tape
范围。因为最好在录制结束的地方做渐变。
- 添加
tape.watch()
以确保它被磁带跟踪。
class A():
def __init__(self):
self.alpha = tf.Variable(2.0)
def f1(self):
wt = self.alpha * 5.0
return wt
def f2(self):
with tf.GradientTape() as tape:
tape.watch(self.alpha)
wt_f1 = self.f1()
wt_f2 = wt_f1 * 10.0
print(tape.gradient(wt_f2, self.alpha))
我提供了一个我想要解决的问题的最小示例。我定义了一个 class 并且其中有一些跨不同函数定义的变量。我想知道如何跨函数跟踪这些变量以获得梯度。我想我必须使用 tf.GradientTape
但我尝试了一些变体但没有成功。
class A():
def __init__(self):
self.alpha = tf.Variable(2.0)
def f1(self):
wt = self.alpha * 5.0
return wt
def f2(self):
wt_f1 = f1()
with tf.GradientTape() as tape:
wt_f2 = wt_f1 * 10.0
print(tape.gradient(wt_f2, self.alpha))
a = A()
print(a.f2())
最后一行returnsNone
。显然 wt_f2
对 alpha
的导数是 50.0。但是,我得到 None
。任何想法?我尝试在 __init__
函数中初始化持久梯度带,并使用它来观察 wt
和 self.alpha
等变量,但这没有帮助。有什么想法吗?
更新 1:
将 wt_f1
调用置于 tape
下无效。
class A():
def __init__(self):
self.alpha = tf.Variable(2.0)
def f1(self):
wt = self.alpha * 5.0
return wt
def f2(self):
with tf.GradientTape() as tape:
wt_f1 = f1()
wt_f2 = wt_f1 * 10.0
print(tape.gradient(wt_f2, self.alpha))
这也是returnsNone
。
您正在打印 None。因为f2()
returns什么都没有,所以得到None
。
删除打印:
a = A()
a.f2()
此外,一些编辑可能对您编写的代码有益。
- 您在
f1()
函数之前错过了self
并且这有效,因为您在其他地方定义了f1
函数。无论如何添加self.f1()
. - 将
print
语句移出tape
范围。因为最好在录制结束的地方做渐变。 - 添加
tape.watch()
以确保它被磁带跟踪。
class A():
def __init__(self):
self.alpha = tf.Variable(2.0)
def f1(self):
wt = self.alpha * 5.0
return wt
def f2(self):
with tf.GradientTape() as tape:
tape.watch(self.alpha)
wt_f1 = self.f1()
wt_f2 = wt_f1 * 10.0
print(tape.gradient(wt_f2, self.alpha))