tf.GradientTape() 的 __exit__ 函数的参数是什么?

What are the parameters to tf.GradientTape()'s __exit__ function?

根据 documentation for tf.GradientTape,其 __exit__() 方法采用三个位置参数:typ, value, traceback.

这些参数到底是什么?

with 语句如何推断它们?

我应该在下面的代码中给它们什么值(我 而不是 使用 with 语句):

x = tf.Variable(5)

gt = tf.GradientTape()
gt.__enter__()
y = x ** 2
gt.__exit__(typ = __, value = __, traceback = __)

sys.exc_info() returns 具有三个值的元组 (type, value, traceback).

  1. 这里type获取正在处理的Exception的异常类型
  2. value 是传递给异常 class.
  3. 构造函数的参数
  4. traceback 包含异常发生位置等堆栈信息

在 GradientTape 上下文中发生异常时 sys.exc_info() 详细信息将传递给 exit() 函数,该函数将 Exits the recording context, no further operations are traced

下面是说明相同的示例。

让我们考虑一个简单的函数。

def f(w1, w2):
    return 3 * w1 ** 2 + 2 * w1 * w2

不使用 with 语句:

w1, w2 = tf.Variable(5.), tf.Variable(3.)

tape = tf.GradientTape()
z = f(w1, w2)
tape.__enter__()
dz_dw1 = tape.gradient(z, w1)
try:
    dz_dw2 = tape.gradient(z, w2)
except Exception as ex:
    print(ex)
    exec_tup = sys.exc_info()
    tape.__exit__(exec_tup[0],exec_tup[1],exec_tup[2])

打印:

GradientTape.gradient can only be called once on non-persistent tapes.

即使你不显式传值退出,程序也会传这些值退出GradientTaoe记录,下面是例子。

w1, w2 = tf.Variable(5.), tf.Variable(3.)

tape = tf.GradientTape()
z = f(w1, w2)
tape.__enter__()
dz_dw1 = tape.gradient(z, w1)
try:
    dz_dw2 = tape.gradient(z, w2)
except Exception as ex:
    print(ex)

打印相同的异常消息。

通过使用with语句。

with tf.GradientTape() as tape:
    z = f(w1, w2)

dz_dw1 = tape.gradient(z, w1)
try:
    dz_dw2 = tape.gradient(z, w2)
except Exception as ex:
    print(ex)
    exec_tup = sys.exc_info()
    tape.__exit__(exec_tup[0],exec_tup[1],exec_tup[2])

下面是对上述异常的 sys.exc_info() 响应。

(RuntimeError,
 RuntimeError('GradientTape.gradient can only be called once on non-persistent tapes.'),
 <traceback at 0x7fcd42dd4208>)

编辑 1:

如评论中user2357112 supports Monica所述。为非异常情况提供解决方案。

在无例外的情况下,规范要求传递给 __exit__ 的值都应该是 None.

示例 1:

x = tf.constant(3.0)
g = tf.GradientTape()
g.__enter__()
g.watch(x)
y = x * x
g.__exit__(None,None,None)
z  = x*x
dy_dx = g.gradient(y, x) 
# dz_dx = g.gradient(z, x) 
print(dy_dx)
# print(dz_dx)

打印:

tf.Tensor(6.0, shape=(), dtype=float32) 

由于 y__exit__ 之前被捕获,它 returns 渐变值。

示例 2:

x = tf.constant(3.0)
g = tf.GradientTape()
g.__enter__()
g.watch(x)
y = x * x
g.__exit__(None,None,None)
z  = x*x
# dy_dx = g.gradient(y, x) 
dz_dx = g.gradient(z, x) 
# print(dy_dx)
print(dz_dx)

打印:

None 

这是因为z是在__exit__之后捕获的,因此渐变停止记录。