PyTorch Batch 屏蔽 select 实现
PyTorch Batch masked select implementation
如何执行 批处理 masked_select
?
鉴于:
x = torch.tensor([[1., 2., 2., 2., 3.],
[1., 2., 4., 3., 2.]])
所需的输出将是:
tensor([[1., 3., 1., 1., 1.],
[1., 4., 3., 1., 1.]])
这是一个可能的方法:
x = torch.tensor([[1., 2., 2., 2., 3.],
[1., 2., 4., 3., 2.]])
ones = torch.tensor([[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]])
masks = torch.tensor([[ True, False, False, False, True],
[ True, False, True, True, False]])
for i in range(x.size(0)):
mask = masks[i]
s = torch.masked_select(x[i], mask)
ones[i][:s.size(0)] = s
是否有其他解决方案?
这类问题的主要问题是中间结果是非同质的:在您的批次中,元素将具有不同数量的掩码值。如果我们想应用 PyTorch 内置函数,这是一个问题。这里我提供两种解决方案来执行此操作。
1- 使用 list comprehension
通过适当数量的批处理元素、掩码和填充:
>>> pad = lambda v: F.pad(v, [0, len(m)-len(v)], value=1)
>>> torch.stack([pad(r[m]) for r, m in zip(x, masks)])
tensor([[1., 3., 1., 1., 1.],
[1., 4., 3., 1., 1.]])
这很简单,与您的方法相似。
2- 使用 torch.scatter
向量化替代方案是构造正确的值和索引张量,以便我们可以应用 torch.scatter
并获得所需的结果。这里的技巧是使用扁平张量。从 x
和 masks
我们首先要访问 nz
和 idx
定义为:
nz
:来自 x
的非屏蔽值(当然使用 masks
), 即 我们需要找到:
tensor([1., 3., 1., 4., 3.])
idx
:它们在输出张量中的对应索引展平。
tensor([ 0, 1, 5, 6, 7])
然后我们可以应用类似 out = ones.scatter(dim=0, idx, nz)
的散点图,这将有效地执行:out[idx[i]] = nz[i]
.
要构造nz
,我们可以直接用masks
索引masks
非零值索引masks
:
>>> nz = x[masks]
tensor([1., 3., 1., 4., 3.])
对于idx
,这会有点棘手。我们可以对掩码本身进行排序,将其展平并使用 torch.Tensor.nonzero
获得非零值。排序后,True
值在每行的开头结束:
>>> idx = masks.sort(1, True).values.view(-1).nonzero()[:,0]
tensor([ 0, 1, 5, 6, 7])
最后我们可以应用 torch.scatter
并重塑以获得所需的结果:
>>> torch.ones(x.numel()).scatter(0, idx, nz).view_as(x)
tensor([[1., 3., 1., 1., 1.],
[1., 4., 3., 1., 1.]])
此处torch.scatter
的使用受到限制,因为输入是一维的。一个等效的方法是简单地:
>>> o = torch.ones(x.numel())
>>> o[idx] = nz
>>> o.view_as(x)
完整方法:
>>> idx = masks.sort(1, True)[0].view(-1).nonzero()[:,0]
>>> torch.ones(x.numel()).scatter(0, idx, x[masks]).view_as(x)
tensor([[1., 3., 1., 1., 1.],
[1., 4., 3., 1., 1.]])
如何执行 批处理 masked_select
?
鉴于:
x = torch.tensor([[1., 2., 2., 2., 3.],
[1., 2., 4., 3., 2.]])
所需的输出将是:
tensor([[1., 3., 1., 1., 1.],
[1., 4., 3., 1., 1.]])
这是一个可能的方法:
x = torch.tensor([[1., 2., 2., 2., 3.],
[1., 2., 4., 3., 2.]])
ones = torch.tensor([[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]])
masks = torch.tensor([[ True, False, False, False, True],
[ True, False, True, True, False]])
for i in range(x.size(0)):
mask = masks[i]
s = torch.masked_select(x[i], mask)
ones[i][:s.size(0)] = s
是否有其他解决方案?
这类问题的主要问题是中间结果是非同质的:在您的批次中,元素将具有不同数量的掩码值。如果我们想应用 PyTorch 内置函数,这是一个问题。这里我提供两种解决方案来执行此操作。
1- 使用 list comprehension
通过适当数量的批处理元素、掩码和填充:
>>> pad = lambda v: F.pad(v, [0, len(m)-len(v)], value=1)
>>> torch.stack([pad(r[m]) for r, m in zip(x, masks)])
tensor([[1., 3., 1., 1., 1.],
[1., 4., 3., 1., 1.]])
这很简单,与您的方法相似。
2- 使用 torch.scatter
向量化替代方案是构造正确的值和索引张量,以便我们可以应用 torch.scatter
并获得所需的结果。这里的技巧是使用扁平张量。从 x
和 masks
我们首先要访问 nz
和 idx
定义为:
nz
:来自x
的非屏蔽值(当然使用masks
), 即 我们需要找到:tensor([1., 3., 1., 4., 3.])
idx
:它们在输出张量中的对应索引展平。tensor([ 0, 1, 5, 6, 7])
然后我们可以应用类似 out = ones.scatter(dim=0, idx, nz)
的散点图,这将有效地执行:out[idx[i]] = nz[i]
.
要构造nz
,我们可以直接用masks
索引masks
非零值索引masks
:
>>> nz = x[masks]
tensor([1., 3., 1., 4., 3.])
对于idx
,这会有点棘手。我们可以对掩码本身进行排序,将其展平并使用 torch.Tensor.nonzero
获得非零值。排序后,True
值在每行的开头结束:
>>> idx = masks.sort(1, True).values.view(-1).nonzero()[:,0]
tensor([ 0, 1, 5, 6, 7])
最后我们可以应用 torch.scatter
并重塑以获得所需的结果:
>>> torch.ones(x.numel()).scatter(0, idx, nz).view_as(x)
tensor([[1., 3., 1., 1., 1.],
[1., 4., 3., 1., 1.]])
此处torch.scatter
的使用受到限制,因为输入是一维的。一个等效的方法是简单地:
>>> o = torch.ones(x.numel())
>>> o[idx] = nz
>>> o.view_as(x)
完整方法:
>>> idx = masks.sort(1, True)[0].view(-1).nonzero()[:,0]
>>> torch.ones(x.numel()).scatter(0, idx, x[masks]).view_as(x)
tensor([[1., 3., 1., 1., 1.],
[1., 4., 3., 1., 1.]])