在 PyTorch 中批量 index_fill
Batched index_fill in PyTorch
我有一个大小为 (2, 3)
:
的索引张量
>>> index = torch.empty(6).random_(0,8).view(2,3)
tensor([[6., 3., 2.],
[3., 4., 7.]])
和大小为(2, 8)
的值张量:
>>> value = torch.zeros(2,8)
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.]])
我想通过沿 dim=-1
的索引将 value
中的元素设置为 1
。** 输出应如下所示:
>>> output
tensor([[0., 0., 1., 1., 0., 0., 1., 0.],
[0., 0., 0., 1., 1., 0., 0., 1.]])
我尝试了 value[range(2), index] = 1
但它触发了一个错误。我也试过 torch.index_fill
但它不接受批量索引。 torch.scatter
需要额外创建一个大小为 2*8
且充满 1
的张量,这会消耗不必要的内存和时间。
您实际上可以通过设置 value
(int)选项而不是 src
选项(张量).
>>> value.scatter_(dim=-1, index=index.long(), value=1)
>>> value
tensor([[0., 0., 1., 1., 0., 0., 1., 0.],
[0., 0., 0., 1., 1., 0., 0., 1.]])
确保 index
是 int64 类型。
我有一个大小为 (2, 3)
:
>>> index = torch.empty(6).random_(0,8).view(2,3)
tensor([[6., 3., 2.],
[3., 4., 7.]])
和大小为(2, 8)
的值张量:
>>> value = torch.zeros(2,8)
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.]])
我想通过沿 dim=-1
的索引将 value
中的元素设置为 1
。** 输出应如下所示:
>>> output
tensor([[0., 0., 1., 1., 0., 0., 1., 0.],
[0., 0., 0., 1., 1., 0., 0., 1.]])
我尝试了 value[range(2), index] = 1
但它触发了一个错误。我也试过 torch.index_fill
但它不接受批量索引。 torch.scatter
需要额外创建一个大小为 2*8
且充满 1
的张量,这会消耗不必要的内存和时间。
您实际上可以通过设置 value
(int)选项而不是 src
选项(张量).
>>> value.scatter_(dim=-1, index=index.long(), value=1)
>>> value
tensor([[0., 0., 1., 1., 0., 0., 1., 0.],
[0., 0., 0., 1., 1., 0., 0., 1.]])
确保 index
是 int64 类型。