torch.einsum 'RuntimeError: dimension mismatch for operand 0: equation 4 tensor 2'

torch.einsum 'RuntimeError: dimension mismatch for operand 0: equation 4 tensor 2'

我正在尝试手动计算矩阵的梯度,我可以使用 numpy 来完成,但我不知道在 pytorch 中做同样的事情。 NumPy 中的等式是

def grad(A, W0, W1, X):
    dim = A.shape
    assert len(dim) == 2
    A_rows = dim[0]
    A_cols = dim[1]    
    gradient = (np.einsum('ik, jl', np.eye(A_cols, A_rows), (((A).dot(X)).dot(W0)).dot(W1).T) + np.einsum('ik, jl', A, ((X).dot(W0)).dot(W1).T))
    return gradient

我在 pytorch 中写了一个函数,但它给我一个错误 'RuntimeError: dimension mismatch for operand 0: equation 4 tensor 2'

我用pytorch写的函数是

def torch_grad(A, W0, W1, X):
    dim = A.shape
    A_rows = dim[0]
    A_cols = dim[1]
    W0W1 = torch.mm(W0, W1)
    AX = torch.mm(A, X)
    AXW0W1 = torch.mm(AX, W0W1)
    XW0W1 = torch.mm(X, W0W1)
    print(torch.eye(A_cols, A_rows).shape, torch.t(AXW0W1).shape)
    e1 = torch.einsum('ik jl', torch.eye(A_cols, A_rows), torch.t(AXW0W1))
    e2 = torch.einsum('ik, jl', A, torch.t(XW0W1))
    return e1 + e2

如果有人能告诉我如何在 pytorch 中实现 numpy 代码,我将不胜感激。 谢谢!

您在第一个 torch.einsum 调用中缺少一个逗号。

e1 = torch.einsum('ik, jl', torch.eye(A_cols, A_rows), torch.t(AXW0W1))

除了打字错误,这不是关于 A 的梯度的计算方式,并且当 AAX 具有不同的大小时它会失败,否则对前传。对于 e1,这应该是一个矩阵乘法,也许这是你的意图,在这种情况下 torch.einsum 应该是 'ik, kl',但这只是执行矩阵乘法的一种过于复杂的方法使用 torch.mm 更简单、更高效。 e2 不涉及针对 A 执行的任何计算,因此它不是梯度的一部分。

def torch_grad(A, W0, W1, X):
    # Forward
    W0W1 = torch.mm(W0, W1)
    AX = torch.mm(A, X)
    AXW0W1 = torch.mm(AX, W0W1)
    XW0W1 = torch.mm(X, W0W1)

    # Backward / Gradients
    rows, cols = AXW0W1.size()
    grad_AX = torch.mm(torch.eye(rows, cols), W0W1.t())
    grad_A = torch.mm(grad_AX, X.t())
    return grad_A

# Autograd version to verify that the gradients are correct
def torch_autograd(A, W0, W1, X):
    # Forward
    W0W1 = torch.mm(W0, W1)
    AX = torch.mm(A, X)
    AXW0W1 = torch.mm(AX, W0W1)
    XW0W1 = torch.mm(X, W0W1)
    
    # Backward / Gradients
    rows, cols = AXW0W1.size()
    AXW0W1.backward(torch.eye(rows, cols))
    return A.grad

# requires_grad=True for the autograd version to track
# gradients with respect to A
A = torch.randn(3, 4, requires_grad=True)
X = torch.randn(4, 5)
W0 = torch.randn(5, 6)
W1 = torch.randn(6, 5)

grad_result = torch_grad(A, W0, W1, X)
autograd_result = torch_autograd(A, W0, W1, X)

torch.equal(grad_result, autograd_result) # => True