y = x / sum(x, dim=0) 的反向传播,其中张量 x 的大小为 (H,W)

Back-Propagation of y = x / sum(x, dim=0) where size of tensor x is (H,W)

Q1.

我正在尝试使用 pytorch 制作我的自定义 autograd 函数。

但是我在使用 y = x / sum(x, dim=0)

进行解析反向传播时遇到了问题

其中张量 x 的大小为(高度,宽度)(x 是二维的)。

这是我的代码

class MyFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
  ctx.save_for_backward(input)
  input = input / torch.sum(input, dim=0)

  return input

@staticmethod
def backward(ctx, grad_output):
  input = ctx.saved_tensors[0]
  H, W = input.size()
  sum = torch.sum(input, dim=0)
  grad_input = grad_output * (1/sum - input*1/sum**2)

  return grad_input

我使用(torch.autograd导入)gradcheck来比较雅可比矩阵,

from torch.autograd import gradcheck
func = MyFunc.apply
input = (torch.randn(3,3,dtype=torch.double,requires_grad=True))
test = gradcheck(func, input)

结果是

请有人帮助我获得正确的反向传播结果

谢谢!


Q2.

感谢解答!

因为你的帮助,我可以在 (H,W) 张量的情况下实现反向传播。

然而,当我在 (N,H,W) 张量的情况下实现反向传播时,我遇到了问题。 我认为问题在于初始化新张量。

这是我的新密码

import torch
import torch.nn as nn
import torch.nn.functional as F

class MyFunc(torch.autograd.Function):
  @staticmethod
  def forward(ctx, input):
    ctx.save_for_backward(input)
    
    N = input.size(0)
    for n in range(N):
      input[n] /= torch.sum(input[n], dim=0)

    return input

  @staticmethod
  def backward(ctx, grad_output):
    input = ctx.saved_tensors[0]
    N, H, W = input.size()
    I = torch.eye(H).unsqueeze(-1)
    sum = input.sum(1)

    grad_input = torch.zeros((N,H,W), dtype = torch.double, requires_grad=True)
    for n in range(N):
      grad_input[n] = ((sum[n] * I - input[n]) * grad_output[n] / sum[n]**2).sum(1)

    return grad_input

毕业检查代码是

from torch.autograd import gradcheck
func = MyFunc.apply
input = (torch.rand(2,2,2,dtype=torch.double,requires_grad=True))
test = gradcheck(func, input)
print(test)

结果是 enter image description here

不知道为什么会报错...

你的帮助对我实现自己的卷积网络很有帮助。

谢谢!祝你有个愉快的一天。

你的雅可比行列式不准确:它是一个 4d 张量,你只计算了它的一个 2D 切片。

你忽略了雅可比矩阵的第二行:

让我们看一个单列的例子,例如:[[x1], [x2], [x3]].

sumx1 + x2 + x3,然后归一化x将得到y = [[y1], [y2], [y3]] = [[x1/sum], [x2/sum], [x3/sum]]。您正在寻找 dL/dx1dL/x2dL/x3 - 我们将把它们写成:dx1dx2dx3 .所有 dL/dyi.

都一样

所以 dx1 等于 dL/dy1*dy1/dx1 + dL/dy2*dy2/dx1 + dL/dy3*dy3/dx1。这是因为 x1 对相应列上的所有输出元素都有贡献:y1y2y3.

我们有:

  • dy1/dx1 = d(x1/sum)/dx1 = (sum - x1)/sum²

  • dy2/dx1 = d(x2/sum)/dx1 = -x2/sum²

  • 类似地,dy3/dx1 = d(x3/sum)/dx1 = -x3/sum²

因此dx1 = (sum - x1)/sum²*dy1 - x2/sum²*dy2 - x3/sum²*dy3dx2dx3 相同。因此,雅可比行列式是 [dxi]_i = (sum - xi)/sum²[dxi]_j = -xj/sum²(所有 j 不同于 i)。

在您的实现中,您似乎缺少所有非对角线分量。

保持相同的单列示例,x1=2x2=3x3=5:

>>> x = torch.tensor([[2.], [3.], [5.]])

>>> sum = input.sum(0)
tensor([10])

雅可比矩阵为:

>>> J = (sum*torch.eye(input.size(0)) - input)/sum**2
tensor([[ 0.0800, -0.0200, -0.0200],
        [-0.0300,  0.0700, -0.0300],
        [-0.0500, -0.0500,  0.0500]])

对于多列的实现,它有点棘手,更具体地说是对角矩阵的形状。将 轴放在最后会更容易,这样我们就不必为广播而烦恼:

>>> x = torch.tensor([[2., 1], [3., 3], [5., 5]])
>>> sum = x.sum(0)
tensor([10.,  9.])

>>> diag = sum*torch.eye(3).unsqueeze(-1).repeat(1, 1, len(sum))
tensor([[[10.,  9.],
         [ 0.,  0.],
         [ 0.,  0.]],

        [[ 0.,  0.],
         [10.,  9.],
         [ 0.,  0.]],

        [[ 0.,  0.],
         [ 0.,  0.],
         [10.,  9.]]])

上方 diag 的形状为 (3, 3, 2),其中两个 位于最后一个轴上。注意我们不需要广播 sum.

不会做的是:torch.eye(3).unsqueeze(0).repeat(len(sum), 1, 1)。由于使用这种形状 - (2, 3, 3) - 你将不得不使用 sum[:, None, None],并且需要进一步广播...

雅可比矩阵很简单:

>>> J = (diag - x)/sum**2
tensor([[[ 0.0800,  0.0988],
         [-0.0300, -0.0370],
         [-0.0500, -0.0617]],

        [[-0.0200, -0.0123],
         [ 0.0700,  0.0741],
         [-0.0500, -0.0617]],

        [[-0.0200, -0.0123],
         [-0.0300, -0.0370],
         [ 0.0500,  0.0494]]])

您可以通过使用任意 dy 向量反向传播操作来检查结果(但不使用 torch.ones,因为 [=54=,您将得到 0s ]!)。反向传播后,x.grad 应等于 torch.einsum('abc,bc->ac', J, dy).

Q2 的答案。

我自己为许多批处理案例实现了反向传播。 我使用了 unsqueeze 功能,它起作用了。

输入大小:(N,H,W)(N 是批量大小)

forward:
  out = input / torch.sum(input, dim=1).unsqueeze(1)

backward:
  diag = torch.eye(input.size(1),  dtype=torch.double, requires_grad=True).unsqueeze(-1)
  sum = input.sum(1)
  grad_input = ((sum.unsqueeze(1).unsqueeze(1) * diag - input.unsqueeze(1)) * grad_out.unsqueeze(1) / (sum**2).unsqueeze(1).unsqueeze(1)).sum(2)