PyTorch 中的后向函数
Backward function in PyTorch
我对 pytorch 的后向函数有一些疑问,我认为我没有得到正确的输出:
import numpy as np
import torch
from torch.autograd import Variable
a = Variable(torch.FloatTensor([[1,2,3],[4,5,6]]), requires_grad=True)
out = a * a
out.backward(a)
print(a.grad)
输出是
tensor([[ 2., 8., 18.],
[32., 50., 72.]])
也许是 2*a*a
但我认为输出应该是
tensor([[ 2., 4., 6.],
[8., 10., 12.]])
2*a.
因为 d(x^2)/dx=2x
请仔细阅读 backward()
上的文档以更好地理解它。
默认情况下,pytorch 期望为网络的 last 输出调用 backward()
- 损失函数。损失函数总是输出标量,因此,标量损失w.r.t所有其他variables/parameters的梯度定义明确(使用链式法则)。
因此,默认情况下,backward()
在标量张量上调用并且不需要任何参数。
例如:
a = torch.tensor([[1,2,3],[4,5,6]], dtype=torch.float, requires_grad=True)
for i in range(2):
for j in range(3):
out = a[i,j] * a[i,j]
out.backward()
print(a.grad)
产量
tensor([[ 2., 4., 6.],
[ 8., 10., 12.]])
符合预期:d(a^2)/da = 2a
。
但是,当您在 2×3 out
张量(不再是标量函数)上调用 backward
时 - 您期望 a.grad
是什么?您实际上需要一个 2×3×2×3 输出:d out[i,j] / d a[k,l]
(!)
Pytorch 不支持这种非标量函数导数。相反,pytorch 假设 out
只是一个中间张量,并且“上游”某处有一个标量损失函数,通过链式法则提供 d loss/ d out[i,j]
。这个“上游”梯度的大小为 2×3,这实际上是您在这种情况下提供的参数 backward
:out.backward(g)
其中 g_ij = d loss/ d out_ij
.
然后用链式法则计算梯度d loss / d a[i,j] = (d loss/d out[i,j]) * (d out[i,j] / d a[i,j])
由于您提供了 a
作为“上游”梯度,您得到了
a.grad[i,j] = 2 * a[i,j] * a[i,j]
如果您要将“上游”梯度提供为全部
out.backward(torch.ones(2,3))
print(a.grad)
产量
tensor([[ 2., 4., 6.],
[ 8., 10., 12.]])
符合预期。
都是链式法则
我对 pytorch 的后向函数有一些疑问,我认为我没有得到正确的输出:
import numpy as np
import torch
from torch.autograd import Variable
a = Variable(torch.FloatTensor([[1,2,3],[4,5,6]]), requires_grad=True)
out = a * a
out.backward(a)
print(a.grad)
输出是
tensor([[ 2., 8., 18.],
[32., 50., 72.]])
也许是 2*a*a
但我认为输出应该是
tensor([[ 2., 4., 6.],
[8., 10., 12.]])
2*a.
因为 d(x^2)/dx=2x
请仔细阅读 backward()
上的文档以更好地理解它。
默认情况下,pytorch 期望为网络的 last 输出调用 backward()
- 损失函数。损失函数总是输出标量,因此,标量损失w.r.t所有其他variables/parameters的梯度定义明确(使用链式法则)。
因此,默认情况下,backward()
在标量张量上调用并且不需要任何参数。
例如:
a = torch.tensor([[1,2,3],[4,5,6]], dtype=torch.float, requires_grad=True)
for i in range(2):
for j in range(3):
out = a[i,j] * a[i,j]
out.backward()
print(a.grad)
产量
tensor([[ 2., 4., 6.], [ 8., 10., 12.]])
符合预期:d(a^2)/da = 2a
。
但是,当您在 2×3 out
张量(不再是标量函数)上调用 backward
时 - 您期望 a.grad
是什么?您实际上需要一个 2×3×2×3 输出:d out[i,j] / d a[k,l]
(!)
Pytorch 不支持这种非标量函数导数。相反,pytorch 假设 out
只是一个中间张量,并且“上游”某处有一个标量损失函数,通过链式法则提供 d loss/ d out[i,j]
。这个“上游”梯度的大小为 2×3,这实际上是您在这种情况下提供的参数 backward
:out.backward(g)
其中 g_ij = d loss/ d out_ij
.
然后用链式法则计算梯度d loss / d a[i,j] = (d loss/d out[i,j]) * (d out[i,j] / d a[i,j])
由于您提供了 a
作为“上游”梯度,您得到了
a.grad[i,j] = 2 * a[i,j] * a[i,j]
如果您要将“上游”梯度提供为全部
out.backward(torch.ones(2,3))
print(a.grad)
产量
tensor([[ 2., 4., 6.], [ 8., 10., 12.]])
符合预期。
都是链式法则