当 `x` 已经是 TensorFlow 中的 `tf.Variable` 时调用 `tape.watch(x)` 可以吗?

Is it ok to call `tape.watch(x)` when `x` is already a `tf.Variable` in TensorFlow?

考虑以下函数

def foo(x):
  with tf.GradientTape() as tape:
    tape.watch(x)

    y = x**2 + x + 4

  return tape.gradient(y, x)

调用函数foo(tf.constant(3.14))时需要调用tape.watch(x),直接传入变量时则不需要调用tape.watch(x),如foo(tf.Variable(3.14)).

现在我的问题是,即使在直接传入 tf.Variable 的情况下,对 tape.watch(x) 的调用是否安全?还是会因为变量已经被自动监视然后再次手动监视而发生一些奇怪的事情?编写这样可以同时接受 tf.Tensortf.Variable 的通用函数的正确方法是什么?

应该是安全的。一方面,tf.GradientTape.watch 的文档说:

Ensures that tensor is being traced by this tape.

"Ensures" 似乎暗示它将确保它被跟踪以防万一。事实上,文档没有给出任何迹象表明在同一个对象上使用它两次应该是一个问题(尽管如果他们明确说明它不会受到伤害)。

但无论如何,我们可以深入源代码进行检查。最后,在变量上调用 watch(如果它不是变量但路径略有不同,答案最终相同)归结为 GradientTape [=] 的 WatchVariable 方法42=] 在 C++ 中:

void WatchVariable(PyObject* v) {
  tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
  if (handle == nullptr) {
    return;
  }
  tensorflow::int64 id = FastTensorId(handle.get());

  if (!PyErr_Occurred()) {
    this->Watch(id);
  }

  tensorflow::mutex_lock l(watched_variables_mu_);
  auto insert_result = watched_variables_.emplace(id, v);

  if (insert_result.second) {
    // Only increment the reference count if we aren't already watching this
    // variable.
    Py_INCREF(v);
  }
}

方法的后半部分显示被监视的变量被添加到watched_variables_,这是一个std::set,所以再次添加一些东西不会做任何事情。这实际上是稍后检查以确保 Python 引用计数是正确的。前半部分基本调用Watch:

template <typename Gradient, typename BackwardFunction, typename TapeTensor>
void GradientTape<Gradient, BackwardFunction, TapeTensor>::Watch(
    int64 tensor_id) {
  tensor_tape_.emplace(tensor_id, -1);
}

tensor_tape_ 是一个映射(特别是 tensorflow::gtl:FlatMap,与标准 C++ 映射几乎相同),因此如果 tensor_id 已经存在,这将没有任何效果。

因此,即使没有明确说明,一切都表明它应该没有问题。

它被设计成供变量使用。来自docs

By default GradientTape will automatically watch any trainable variables that are accessed inside the context. If you want fine grained control over which variables are watched you can disable automatic tracking by passing watch_accessed_variables=False to the tape constructor:

with tf.GradientTape(watch_accessed_variables=False) as tape:
  tape.watch(variable_a)
  y = variable_a ** 2  # Gradients will be available for `variable_a`.
  z = variable_b ** 3  # No gradients will be available since `variable_b` is
                       # not being watched.