如何对这些类型的张量使用掩码 select?
How to use masked select for these kind of tensors?
假设,我有一个张量a
和一个张量b
import torch
a = torch.tensor([[[ 0.8856, 0.1411, -0.1856, -0.1425],
[-0.0971, 0.1251, 0.1608, -0.1302],
[-0.0901, 0.3215, 0.1763, -0.0412]],
[[ 0.8856, 0.1411, -0.1856, -0.1425],
[-0.0971, 0.1251, 0.1608, -0.1302],
[-0.0901, 0.3215, 0.1763, -0.0412]]])
b = torch.tensor([[0,
2,
1],
[0,
2,
1]])
现在,我想 select 来自张量 a
的索引,其中张量 b
的值不为 0。
pred_masks = ( b != 0 )
c = torch.masked_select( a, (pred_masks == 1))
当然,我得到了预期的错误。
----> 1 c = torch.masked_select( a, (pred_masks == 1))
RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 2
这是由包含 4 个项目的嵌套列表引起的。但是,需要select张量a
中索引x处的嵌套列表的所有值,对应张量b
中的索引x。
如有任何提示或回答,我将不胜感激。
我不确定你想要什么作为输出 c 的形状。由于您的掩码的形状为 (2,3) 而 a 的形状为 (2,3,4) 您是否希望输出形状为 (n,4) 的张量,其中 n 是在 ( 2,3)-面具?
如果是,那么我建议只使用掩码作为前两个维度的索引。
c = a[pred_masks,:]
希望对您有所帮助。
假设,我有一个张量a
和一个张量b
import torch
a = torch.tensor([[[ 0.8856, 0.1411, -0.1856, -0.1425],
[-0.0971, 0.1251, 0.1608, -0.1302],
[-0.0901, 0.3215, 0.1763, -0.0412]],
[[ 0.8856, 0.1411, -0.1856, -0.1425],
[-0.0971, 0.1251, 0.1608, -0.1302],
[-0.0901, 0.3215, 0.1763, -0.0412]]])
b = torch.tensor([[0,
2,
1],
[0,
2,
1]])
现在,我想 select 来自张量 a
的索引,其中张量 b
的值不为 0。
pred_masks = ( b != 0 )
c = torch.masked_select( a, (pred_masks == 1))
当然,我得到了预期的错误。
----> 1 c = torch.masked_select( a, (pred_masks == 1))
RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 2
这是由包含 4 个项目的嵌套列表引起的。但是,需要select张量a
中索引x处的嵌套列表的所有值,对应张量b
中的索引x。
如有任何提示或回答,我将不胜感激。
我不确定你想要什么作为输出 c 的形状。由于您的掩码的形状为 (2,3) 而 a 的形状为 (2,3,4) 您是否希望输出形状为 (n,4) 的张量,其中 n 是在 ( 2,3)-面具?
如果是,那么我建议只使用掩码作为前两个维度的索引。
c = a[pred_masks,:]
希望对您有所帮助。