以下 Pytorch 结果背后的解释

Explanation behind the following Pytorch results

我想更深入地了解 Pytorch 的 autograd 是如何工作的。我无法解释以下结果:

import torch
def fn(a):
 b = torch.tensor(5,dtype=torch.float32,requires_grad=True)
 return a*b 

a  = torch.tensor(10,dtype=torch.float32,requires_grad=True)
output = fn(a)
output.backward()
print(a.grad)

输出为张量(5.)。但我的问题是变量 b 是在函数内创建的,因此应该在函数 returns a*b 之后从内存中删除,对吗?因此,当我向后调用时,b 的值如何仍然存在以允许此计算? 据我所知,Pytorch 中的每个操作都有一个上下文变量,它跟踪“哪个”张量用于向后计算,并且每个张量中也存在版本,如果版本发生变化,那么向后应该会引发错误吗?

现在,当我尝试 运行 以下代码时,

import torch
def fn(a):
 b = a**2
 for i in range(5):
   b *= b
 return b 

a  = torch.tensor(10,dtype=torch.float32,requires_grad=True)
output = fn(a)
output.backward()
print(a.grad)

我收到以下错误:梯度计算所需的变量之一已被就地操作修改:[torch.FloatTensor []],即 MulBackward0 的输出 0,版本为 5;预期版本 4 相反。提示:启用异常检测以找到未能计算其梯度的操作,其中 torch.autograd.set_detect_anomaly(True).

但是如果我运行下面的代码,是没有错误的:

import torch
def fn(a):
  b = a**2
  for i in range(2):
    b = b*b
  return b

def fn2(a):
  b = a**2
  c = a**2
  for i in range(2):
    c *= b
  return c

a  = torch.tensor(5,dtype=torch.float32,requires_grad=True)
output = fn(a)
output.backward()
print(a.grad)
output2 = fn2(a)
output2.backward()
print(a.grad)

这个输出是:

张量(625000.)

张量(643750.)

因此对于具有相当多变量的标准计算图,在同一个函数中,我能够理解计算图的工作原理。但是当在调用向后函数之前有一个变量发生变化时,我在理解结果时遇到了很多麻烦。有人可以解释一下吗?

请注意 b *=bb = b*b 不同。

这可能令人困惑,但底层操作各不相同。

b *=b 的情况下,发生了一个就地操作,它弄乱了梯度,因此 RuntimeError

b = b*b 的情况下,两个张量对象相乘,所得对象的名称为 b。因此当你这样运行时没有RuntimeError

这是一个关于基础 python 操作的 SO 问题:

第一种情况下的fn和第二种情况下的fn2有什么区别? c*=b 操作不会破坏从 cb 的图形链接。操作 c*=c 将无法通过操作连接两个张量。

好吧,我无法使用张量来展示它,因为它们会引发 RuntimeError。所以我会尝试使用 python list.

>>> x = [1,2]
>>> y = [3]
>>> id(x), id(y)
(140192646516680, 140192646927112)
>>>
>>> x += y
>>> x, y
([1, 2, 3], [3])
>>> id(x), id(y)
(140192646516680, 140192646927112)

请注意,没有创建新对象。所以不可能从 output 追踪到初始变量。我们无法区分 object_140192646516680 是输出还是输入。那么如何用它创建图表..

考虑以下替代情况:

>>> a = [1,2]
>>> b = [3]
>>>
>>> id(a), id(b)
(140192666168008, 140192666168264)
>>>
>>> a = a + b
>>> a, b
([1, 2, 3], [3])
>>> id(a), id(b)
(140192666168328, 140192666168264)
>>>

请注意,新列表 a 实际上是具有 id 140192666168328 的新对象。在这里我们可以追踪到 object_140192666168328 来自另外两个对象 object_140192666168008object_140192666168264 之间的 addition operation。因此,可以动态创建图形,并且可以将梯度从 output 传播回之前的层。