将 [3, 2, 3] 乘以 pytorch 中的 [3, 2] 张量(沿维度的点积)

Multiply a [3, 2, 3] by a [3, 2] tensor in pytorch (dot product along dimension)

给定以下张量 xy,形状为 [3,2,3][3,2]。我想沿二维乘以张量,这应该是一种点积并沿轴缩放 return [3,2,3] 张量。

import torch
a  = [[[0.2,0.3,0.5],[-0.5,0.02,1.0]],[[0.01,0.13,0.06],[0.35,0.12,0.0]], [[1.0,-0.3,1.0],[1.0,0.02, 0.03]] ]
b = [[1,2],[1,3],[0,2]]
x = torch.FloatTensor(a) # shape [3,2,3]
y = torch.FloatTensor(b) # shape [3,2]

预期输出:

Expected output shape should be [3,2,3]
#output = [[[0.2,0.3,0.5],[-1.0,0.04,2.0]],[[0.01,0.13,0.06],[1.05,0.36,0.0]], [[0.0,0.0,0.0],[2.0,0.04, 0.06]] ]

我已经尝试了下面的两个,但是 none 给出了所需的输出和输出形状。

torch.matmul(x,y)
torch.matmul(x,y.unsqueeze(1).shape)

解决此问题的最佳方法是什么?

这只是 broadcasted 相乘。所以可以在y的末尾插入一个幺正维度,使其成为[3,2,1]张量,然后乘以x。有多种插入单一维度的方法。

# all equivalent
x * y.unsqueeze(2)
x * y[..., None]
x * y[:, :, None]
x * y.reshape(3, 2, 1)

您也可以使用 torch.einsum.

torch.einsum('abc,ab->abc', x, y)