在pytorch中将矩阵的行乘以向量元素?

Multiply rows of matrix by vector elementwise in pytorch?

我想执行以下操作,但使用 PyTorch

下面的例子和描述来自这个post.

I have a numeric matrix with 25 columns and 23 rows, and a vector of length 25. How can I multiply each row of the matrix by the vector without using a for loop?

The result should be a 25x23 matrix (the same size as the input), but each row has been multiplied by the vector.

R 中的示例代码(来源:reproducible example from @hatmatrix's answer):

matrix <- matrix(rep(1:3,each=5),nrow=3,ncol=5,byrow=TRUE)

     [,1] [,2] [,3] [,4] [,5]
[1,]    1    1    1    1    1
[2,]    2    2    2    2    2
[3,]    3    3    3    3    3

vector <- 1:5

期望的输出:

     [,1] [,2] [,3] [,4] [,5]
[1,]    1    2    3    4    5
[2,]    2    4    6    8   10
[3,]    3    6    9   12   15

使用 Pytorch 执行此操作的最佳方法是什么?

答案太琐碎了,我忽略了。

为简单起见,我在此答案中使用了较小的向量和矩阵。

矩阵行乘以向量:

X = torch.tensor([[1,2,3],[5,6,7]])                                                                                                                                                                          
y = torch.tensor([7,4])                                                                                                                                                                                   
X.transpose(0,1)*y
# or alternatively
y*X.transpose(0,1)

输出:

tensor([[ 7, 20],
        [14, 24],
        [21, 28]])

tensor([[ 7, 20],
        [14, 24],
        [21, 28]])

用向量乘以矩阵的列:

要将矩阵的列乘以向量,您可以使用相同的运算符“*”,但无需先转置矩阵(或向量)

X = torch.tensor([[3, 5],[5, 5],[1, 0]])                                                                                                                                                                          
y = torch.tensor([7,4])                                                                                                                                                                                   
X*y
# or alternatively
y*X

输出:

tensor([[21, 20],
        [35, 20],
        [ 7,  0]])

tensor([[21, 20],
        [35, 20],
        [ 7,  0]])