Pytorch 在给定条件的情况下计算特定行的二维张量的平均值

Pytorch compute the mean of a 2D tensor at specific rows given a condition

假设我有一个张量

tensor([[0, 1, 2, 2],
        [2, 4, 2, 4],
        [3, 4, 3, 1],
        [4, 4, 4, 3]])

和指数张量

tensor([[1],
        [2],
        [1],
        [3]])

我想计算索引值匹配的平均值。在这种情况下,我想要第 1 行和第 3 行的平均值,因此最终输出将是

tensor([[1.5, 2.5, 2.5, 1.5],
        [2,   4,   2,   4],
        [4,   4,   4,   3]])

您可以使用 torch.scatter_reduce 来计算总和。要计算平均值,我们必须使用它两次,一次用于计算和,一次用于计算被加数,这样我们就可以除以计数的数量。不过有一个细节是,由于 pytorch 使用基于 0 的索引,我们需要从这些值中减去 1:

import torch
a = torch.tensor([[0, 1, 2, 2], [2, 4, 2, 4], [3, 4, 3, 1], [4, 4, 4, 3]])
b = torch.tensor([[1], [2], [1], [3]])
cc = torch.tensor([[1.5, 5.2, 5.2, 1.5], [2,   4,   2,   4], [4,   4,   4,   3]]) # goal

c = torch.scatter_reduce(
    a.to(float),
    0,
    torch.broadcast_to(b, a.shape) - 1,
    reduce='mean'
)
print(c)