在 PyTorch 中批量 index_fill

Batched index_fill in PyTorch

我有一个大小为 (2, 3):

的索引张量
>>> index = torch.empty(6).random_(0,8).view(2,3)
tensor([[6., 3., 2.],
        [3., 4., 7.]])

和大小为(2, 8)的值张量:

>>> value = torch.zeros(2,8)
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

我想通过沿 dim=-1 的索引将 value 中的元素设置为 1。** 输出应如下所示:

>>> output
tensor([[0., 0., 1., 1., 0., 0., 1., 0.],
        [0., 0., 0., 1., 1., 0., 0., 1.]])

我尝试了 value[range(2), index] = 1 但它触发了一个错误。我也试过 torch.index_fill 但它不接受批量索引。 torch.scatter 需要额外创建一个大小为 2*8 且充满 1 的张量,这会消耗不必要的内存和时间。

您实际上可以通过设置 valueint)选项而不是 src 选项(张量).

>>> value.scatter_(dim=-1, index=index.long(), value=1)

>>> value
tensor([[0., 0., 1., 1., 0., 0., 1., 0.],
        [0., 0., 0., 1., 1., 0., 0., 1.]])

确保 indexint64 类型。