PyTorch:具有相同输出 bin 索引的所有数据点的逐元素最大值

PyTorch: element-wise max over all data points with the same output bin index

我正在使用 PyTorch (1.8)。有没有一种聪明的方法可以对具有相同输出索引的所有数据点进行逐元素取最大值?

假设我有一个大小为 (N, M) 的数据张量和一个大小为 (N,) 的包含索引 [0, K] 的索引张量。 现在我想根据索引值将数据张量分箱成大小为 (K, M) 的张量,但是如果两个或更多数据点分箱到同一个槽中,那么我想保留元素方面的最大值。

我见过像下面这样的天真方法,但没有给出元素方面的最大值,而是只存储最后合并的内容。

data = torch.randn((N, M))
index = torch.randint(K, (N,))
output = torch.zeros((K, M))

output[index] = data

目前我正在实现自定义 cuda 内核来解决这个问题,但想知道这是否可以用标准 PyTorch 解决。

编辑:最小示例:

data = torch.tensor([[10,1],[9,2],[8,3],[7,4],[6,5]])
index = torch.tensor([2,1,0,1,2], dtype=torch.long)
# something happens
# expected output: 
# [[8, 3], [9, 4], [10, 5]]

PyTorch 似乎还没有这方面的本地实现,但有一个存储库可以做到这一点。 PyTorch Scatter

我所描述的似乎与 scatter_max 相对应。

from torch_scatter import scatter_max

scatter_max(data, index, dim=0)