torch.Tensor.backward() 如何运作?
How torch.Tensor.backward() works?
最近在研究Pytorch和封装的backward函数。
我知道如何使用它,但是当我尝试
x = Variable(2*torch.ones(2, 2), requires_grad=True)
x.backward(x)
print(x.grad)
我希望
tensor([[1., 1.],
[1., 1.]])
因为它是恒等函数。然而,它 returns
tensor([[2., 2.],
[2., 2.]]).
为什么会这样?
其实,这就是你要找的:
情况一:当z = 2*x**3 + x
import torch
from torch.autograd import Variable
x = Variable(2*torch.ones(2, 2), requires_grad=True)
z = x*x*x*2+x
z.backward(torch.ones_like(z))
print(x.grad)
输出:
tensor([[25., 25.],
[25., 25.]])
情况2:当z = x*x
x = Variable(2*torch.ones(2, 2), requires_grad=True)
z = x*x
z.backward(torch.ones_like(z))
print(x.grad)
输出:
tensor([[4., 4.],
[4., 4.]])
情况 3:当 z = x(您的情况)
x = Variable(2*torch.ones(2, 2), requires_grad=True)
z = x
z.backward(torch.ones_like(z))
print(x.grad)
输出:
tensor([[1., 1.],
[1., 1.]])
要了解更多如何在 pytorch 中计算梯度,请查看 。
我想你误解了如何使用 tensor.backward()
。 backward()
里面的参数不是dy/dx.
的x
例如,如果通过某种操作从x得到y,那么y.backward(w)
,首先pytorch会得到l = dot(y,w)
,然后计算dl/dx
。
所以对于你的代码,l = 2x
首先由pytorch计算,然后dl/dx
是你的代码returns。
当你做y.backward(w)
时,如果y不是标量,就让backward()
的参数全为1;否则就没有参数。
最近在研究Pytorch和封装的backward函数。 我知道如何使用它,但是当我尝试
x = Variable(2*torch.ones(2, 2), requires_grad=True)
x.backward(x)
print(x.grad)
我希望
tensor([[1., 1.],
[1., 1.]])
因为它是恒等函数。然而,它 returns
tensor([[2., 2.],
[2., 2.]]).
为什么会这样?
其实,这就是你要找的:
情况一:当z = 2*x**3 + x
import torch
from torch.autograd import Variable
x = Variable(2*torch.ones(2, 2), requires_grad=True)
z = x*x*x*2+x
z.backward(torch.ones_like(z))
print(x.grad)
输出:
tensor([[25., 25.],
[25., 25.]])
情况2:当z = x*x
x = Variable(2*torch.ones(2, 2), requires_grad=True)
z = x*x
z.backward(torch.ones_like(z))
print(x.grad)
输出:
tensor([[4., 4.],
[4., 4.]])
情况 3:当 z = x(您的情况)
x = Variable(2*torch.ones(2, 2), requires_grad=True)
z = x
z.backward(torch.ones_like(z))
print(x.grad)
输出:
tensor([[1., 1.],
[1., 1.]])
要了解更多如何在 pytorch 中计算梯度,请查看
我想你误解了如何使用 tensor.backward()
。 backward()
里面的参数不是dy/dx.
例如,如果通过某种操作从x得到y,那么y.backward(w)
,首先pytorch会得到l = dot(y,w)
,然后计算dl/dx
。
所以对于你的代码,l = 2x
首先由pytorch计算,然后dl/dx
是你的代码returns。
当你做y.backward(w)
时,如果y不是标量,就让backward()
的参数全为1;否则就没有参数。