如何对这些类型的张量使用掩码 select?

How to use masked select for these kind of tensors?

假设,我有一个张量a和一个张量b

import torch
a = torch.tensor([[[ 0.8856,  0.1411, -0.1856, -0.1425],
    [-0.0971,  0.1251,  0.1608, -0.1302],
    [-0.0901,  0.3215,  0.1763, -0.0412]], 
    [[ 0.8856,  0.1411, -0.1856, -0.1425],
    [-0.0971,  0.1251,  0.1608, -0.1302],
    [-0.0901,  0.3215,  0.1763, -0.0412]]])
b = torch.tensor([[0,
    2,
    1], 
    [0,
    2,
    1]])

现在,我想 select 来自张量 a 的索引,其中张量 b 的值不为 0。

pred_masks = ( b != 0 )
c = torch.masked_select( a, (pred_masks == 1))

当然,我得到了预期的错误。

----> 1 c = torch.masked_select( a, (pred_masks == 1))

RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 2

这是由包含 4 个项目的嵌套列表引起的。但是,需要select张量a中索引x处的嵌套列表的所有值,对应张量b中的索引x。

如有任何提示或回答,我将不胜感激。

我不确定你想要什么作为输出 c 的形状。由于您的掩码的形状为 (2,3) 而 a 的形状为 (2,3,4) 您是否希望输出形状为 (n,4) 的张量,其中 n 是在 ( 2,3)-面具?

如果是,那么我建议只使用掩码作为前两个维度的索引。

c = a[pred_masks,:]

希望对您有所帮助。