PyTorch 使用乘法按列减少

PyTorch Column-wise reduction using multiplication

我想通过将同一行中的所有值相乘来减少 Torch 张量中的列。 所以,例如:

x = torch.tensor([[1,1,1],[1,1,0],[1,1,2], [1,2,2]])

形状是4*3。减少后,我想要一个形状为 4 的张量,每个值都是每列的乘积,即。

x_reduced = torch.tensor([1,0,2,4])

有手电筒操作员可以轻松做到这一点吗?

是的,函数调用很简单:torch.prod(x,dim = 1).