使用特定值创建 pytorch 张量二进制掩码
Creating a pytorch tensor binary mask using specific values
我得到一个带有整数的 pytorch 二维张量,以及始终出现在张量的每一行中的 2 个整数。
我想创建一个二进制掩码,它将在这 2 个整数的 two 出现之间包含 1,否则为 0。例如,如果整数是 4 和 2,并且一维数组是[1,1,9,4,6,5,1,2,9,9,11,4,3,6,5,2,3,4]
,返回的掩码将是:[0,0,0,1,1,1,1,1,0,0,0,1,1,1,1,1,0,0,0].
有没有无需迭代即可高效快速地计算此掩码的方法?
可能有点乱,但它无需迭代即可工作。在下文中,我假设了一个示例张量 m
,我将其应用于解决方案,用它来解释比使用一般符号更容易。
import torch
vals=[2,8]#let's assume those are the constant values that appear in each row
#target tensor
m=torch.tensor([[1., 2., 7., 8., 5.],
[4., 7., 2., 1., 8.]])
#let's find the indexes of those values
k=m==vals[0]
p=m==vals[1]
v=(k.int()+p.int()).bool()
nz_indexes=v.nonzero()[:,1].reshape(m.shape[0],2)
#let's create a tiling of the indexes
q=torch.arange(m.shape[1])
q=q.repeat(m.shape[0],1)
#you only need two masks, no matter the size of m. see explanation below
msk_0=(nz_indexes[:,0].repeat(m.shape[1],1).transpose(0,1))<=q
msk_1=(nz_indexes[:,1].repeat(m.shape[1],1).transpose(0,1))>=q
final_mask=msk_0.int() * msk_1.int()
print(final_mask)
我们得到
tensor([[0, 1, 1, 1, 0],
[0, 0, 1, 1, 1]], dtype=torch.int32)
关于mask_0
和mask_1
这两个掩码,如果不清楚它们是什么,请注意nz_indexes[:,0]
containts,对于m
的每一行,列索引在其中找到 vals[0]
,并且 nz_indexes[:,1]
同样包含对于 m
的每一行,在其中找到 vals[1]
的列索引。
完全基于以前的解决方案,这里是修改后的解决方案:
import torch
vals=[2,8]#let's assume those are the constant values that appear in each row
#target tensor
m=torch.tensor([[1., 2., 7., 8., 5., 2., 6., 5., 8., 4.],
[4., 7., 2., 1., 8., 2., 6., 5., 6., 8.]])
#let's find the indexes of those values
k=m==vals[0]
p=m==vals[1]
v=(k.int()+p.int()).bool()
nz_indexes=v.nonzero()[:,1].reshape(m.shape[0],4)
#let's create a tiling of the indexes
q=torch.arange(m.shape[1])
q=q.repeat(m.shape[0],1)
#you only need two masks, no matter the size of m. see explanation below
msk_0=(nz_indexes[:,0].repeat(m.shape[1],1).transpose(0,1))<=q
msk_1=(nz_indexes[:,1].repeat(m.shape[1],1).transpose(0,1))>=q
msk_2=(nz_indexes[:,2].repeat(m.shape[1],1).transpose(0,1))<=q
msk_3=(nz_indexes[:,3].repeat(m.shape[1],1).transpose(0,1))>=q
final_mask=msk_0.int() * msk_1.int() + msk_2.int() * msk_3.int()
print(final_mask)
我们终于得到了
tensor([[0, 1, 1, 1, 0, 1, 1, 1, 1, 0],
[0, 0, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int32)
我得到一个带有整数的 pytorch 二维张量,以及始终出现在张量的每一行中的 2 个整数。
我想创建一个二进制掩码,它将在这 2 个整数的 two 出现之间包含 1,否则为 0。例如,如果整数是 4 和 2,并且一维数组是[1,1,9,4,6,5,1,2,9,9,11,4,3,6,5,2,3,4]
,返回的掩码将是:[0,0,0,1,1,1,1,1,0,0,0,1,1,1,1,1,0,0,0].
有没有无需迭代即可高效快速地计算此掩码的方法?
可能有点乱,但它无需迭代即可工作。在下文中,我假设了一个示例张量 m
,我将其应用于解决方案,用它来解释比使用一般符号更容易。
import torch
vals=[2,8]#let's assume those are the constant values that appear in each row
#target tensor
m=torch.tensor([[1., 2., 7., 8., 5.],
[4., 7., 2., 1., 8.]])
#let's find the indexes of those values
k=m==vals[0]
p=m==vals[1]
v=(k.int()+p.int()).bool()
nz_indexes=v.nonzero()[:,1].reshape(m.shape[0],2)
#let's create a tiling of the indexes
q=torch.arange(m.shape[1])
q=q.repeat(m.shape[0],1)
#you only need two masks, no matter the size of m. see explanation below
msk_0=(nz_indexes[:,0].repeat(m.shape[1],1).transpose(0,1))<=q
msk_1=(nz_indexes[:,1].repeat(m.shape[1],1).transpose(0,1))>=q
final_mask=msk_0.int() * msk_1.int()
print(final_mask)
我们得到
tensor([[0, 1, 1, 1, 0],
[0, 0, 1, 1, 1]], dtype=torch.int32)
关于mask_0
和mask_1
这两个掩码,如果不清楚它们是什么,请注意nz_indexes[:,0]
containts,对于m
的每一行,列索引在其中找到 vals[0]
,并且 nz_indexes[:,1]
同样包含对于 m
的每一行,在其中找到 vals[1]
的列索引。
完全基于以前的解决方案,这里是修改后的解决方案:
import torch
vals=[2,8]#let's assume those are the constant values that appear in each row
#target tensor
m=torch.tensor([[1., 2., 7., 8., 5., 2., 6., 5., 8., 4.],
[4., 7., 2., 1., 8., 2., 6., 5., 6., 8.]])
#let's find the indexes of those values
k=m==vals[0]
p=m==vals[1]
v=(k.int()+p.int()).bool()
nz_indexes=v.nonzero()[:,1].reshape(m.shape[0],4)
#let's create a tiling of the indexes
q=torch.arange(m.shape[1])
q=q.repeat(m.shape[0],1)
#you only need two masks, no matter the size of m. see explanation below
msk_0=(nz_indexes[:,0].repeat(m.shape[1],1).transpose(0,1))<=q
msk_1=(nz_indexes[:,1].repeat(m.shape[1],1).transpose(0,1))>=q
msk_2=(nz_indexes[:,2].repeat(m.shape[1],1).transpose(0,1))<=q
msk_3=(nz_indexes[:,3].repeat(m.shape[1],1).transpose(0,1))>=q
final_mask=msk_0.int() * msk_1.int() + msk_2.int() * msk_3.int()
print(final_mask)
我们终于得到了
tensor([[0, 1, 1, 1, 0, 1, 1, 1, 1, 0],
[0, 0, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int32)