torch.einsum 'RuntimeError: dimension mismatch for operand 0: equation 4 tensor 2'
torch.einsum 'RuntimeError: dimension mismatch for operand 0: equation 4 tensor 2'
我正在尝试手动计算矩阵的梯度,我可以使用 numpy 来完成,但我不知道在 pytorch 中做同样的事情。
NumPy 中的等式是
def grad(A, W0, W1, X):
dim = A.shape
assert len(dim) == 2
A_rows = dim[0]
A_cols = dim[1]
gradient = (np.einsum('ik, jl', np.eye(A_cols, A_rows), (((A).dot(X)).dot(W0)).dot(W1).T) + np.einsum('ik, jl', A, ((X).dot(W0)).dot(W1).T))
return gradient
我在 pytorch 中写了一个函数,但它给我一个错误 'RuntimeError: dimension mismatch for operand 0: equation 4 tensor 2'
我用pytorch写的函数是
def torch_grad(A, W0, W1, X):
dim = A.shape
A_rows = dim[0]
A_cols = dim[1]
W0W1 = torch.mm(W0, W1)
AX = torch.mm(A, X)
AXW0W1 = torch.mm(AX, W0W1)
XW0W1 = torch.mm(X, W0W1)
print(torch.eye(A_cols, A_rows).shape, torch.t(AXW0W1).shape)
e1 = torch.einsum('ik jl', torch.eye(A_cols, A_rows), torch.t(AXW0W1))
e2 = torch.einsum('ik, jl', A, torch.t(XW0W1))
return e1 + e2
如果有人能告诉我如何在 pytorch 中实现 numpy 代码,我将不胜感激。
谢谢!
您在第一个 torch.einsum
调用中缺少一个逗号。
e1 = torch.einsum('ik, jl', torch.eye(A_cols, A_rows), torch.t(AXW0W1))
除了打字错误,这不是关于 A
的梯度的计算方式,并且当 A
和 AX
具有不同的大小时它会失败,否则对前传。对于 e1
,这应该是一个矩阵乘法,也许这是你的意图,在这种情况下 torch.einsum
应该是 'ik, kl'
,但这只是执行矩阵乘法的一种过于复杂的方法使用 torch.mm
更简单、更高效。 e2
不涉及针对 A
执行的任何计算,因此它不是梯度的一部分。
def torch_grad(A, W0, W1, X):
# Forward
W0W1 = torch.mm(W0, W1)
AX = torch.mm(A, X)
AXW0W1 = torch.mm(AX, W0W1)
XW0W1 = torch.mm(X, W0W1)
# Backward / Gradients
rows, cols = AXW0W1.size()
grad_AX = torch.mm(torch.eye(rows, cols), W0W1.t())
grad_A = torch.mm(grad_AX, X.t())
return grad_A
# Autograd version to verify that the gradients are correct
def torch_autograd(A, W0, W1, X):
# Forward
W0W1 = torch.mm(W0, W1)
AX = torch.mm(A, X)
AXW0W1 = torch.mm(AX, W0W1)
XW0W1 = torch.mm(X, W0W1)
# Backward / Gradients
rows, cols = AXW0W1.size()
AXW0W1.backward(torch.eye(rows, cols))
return A.grad
# requires_grad=True for the autograd version to track
# gradients with respect to A
A = torch.randn(3, 4, requires_grad=True)
X = torch.randn(4, 5)
W0 = torch.randn(5, 6)
W1 = torch.randn(6, 5)
grad_result = torch_grad(A, W0, W1, X)
autograd_result = torch_autograd(A, W0, W1, X)
torch.equal(grad_result, autograd_result) # => True
我正在尝试手动计算矩阵的梯度,我可以使用 numpy 来完成,但我不知道在 pytorch 中做同样的事情。 NumPy 中的等式是
def grad(A, W0, W1, X):
dim = A.shape
assert len(dim) == 2
A_rows = dim[0]
A_cols = dim[1]
gradient = (np.einsum('ik, jl', np.eye(A_cols, A_rows), (((A).dot(X)).dot(W0)).dot(W1).T) + np.einsum('ik, jl', A, ((X).dot(W0)).dot(W1).T))
return gradient
我在 pytorch 中写了一个函数,但它给我一个错误 'RuntimeError: dimension mismatch for operand 0: equation 4 tensor 2'
我用pytorch写的函数是
def torch_grad(A, W0, W1, X):
dim = A.shape
A_rows = dim[0]
A_cols = dim[1]
W0W1 = torch.mm(W0, W1)
AX = torch.mm(A, X)
AXW0W1 = torch.mm(AX, W0W1)
XW0W1 = torch.mm(X, W0W1)
print(torch.eye(A_cols, A_rows).shape, torch.t(AXW0W1).shape)
e1 = torch.einsum('ik jl', torch.eye(A_cols, A_rows), torch.t(AXW0W1))
e2 = torch.einsum('ik, jl', A, torch.t(XW0W1))
return e1 + e2
如果有人能告诉我如何在 pytorch 中实现 numpy 代码,我将不胜感激。 谢谢!
您在第一个 torch.einsum
调用中缺少一个逗号。
e1 = torch.einsum('ik, jl', torch.eye(A_cols, A_rows), torch.t(AXW0W1))
除了打字错误,这不是关于 A
的梯度的计算方式,并且当 A
和 AX
具有不同的大小时它会失败,否则对前传。对于 e1
,这应该是一个矩阵乘法,也许这是你的意图,在这种情况下 torch.einsum
应该是 'ik, kl'
,但这只是执行矩阵乘法的一种过于复杂的方法使用 torch.mm
更简单、更高效。 e2
不涉及针对 A
执行的任何计算,因此它不是梯度的一部分。
def torch_grad(A, W0, W1, X):
# Forward
W0W1 = torch.mm(W0, W1)
AX = torch.mm(A, X)
AXW0W1 = torch.mm(AX, W0W1)
XW0W1 = torch.mm(X, W0W1)
# Backward / Gradients
rows, cols = AXW0W1.size()
grad_AX = torch.mm(torch.eye(rows, cols), W0W1.t())
grad_A = torch.mm(grad_AX, X.t())
return grad_A
# Autograd version to verify that the gradients are correct
def torch_autograd(A, W0, W1, X):
# Forward
W0W1 = torch.mm(W0, W1)
AX = torch.mm(A, X)
AXW0W1 = torch.mm(AX, W0W1)
XW0W1 = torch.mm(X, W0W1)
# Backward / Gradients
rows, cols = AXW0W1.size()
AXW0W1.backward(torch.eye(rows, cols))
return A.grad
# requires_grad=True for the autograd version to track
# gradients with respect to A
A = torch.randn(3, 4, requires_grad=True)
X = torch.randn(4, 5)
W0 = torch.randn(5, 6)
W1 = torch.randn(6, 5)
grad_result = torch_grad(A, W0, W1, X)
autograd_result = torch_autograd(A, W0, W1, X)
torch.equal(grad_result, autograd_result) # => True