Torch - 用另一个矩阵查询矩阵

Torch - Query matrix with another matrix

我有一个 m x n 张量(张量 1)和另一个 k x 2 张量(张量 2),我希望使用基于张量 2 的索引提取张量 1 的所有值。例如;

Tensor1
  1   2   3   4   5
  6   7   8   9  10
 11  12  13  14  15
 16  17  18  19  20
[torch.DoubleTensor of size 4x5]

Tensor2
 2  1
 3  5
 1  1
 4  3
[torch.DoubleTensor of size 4x2]

函数会产生;

6
15
1
18

想到的第一个解决方案是简单地遍历索引并选择相应的值:

function get_elems_simple(tensor, indices)
    local res = torch.Tensor(indices:size(1)):typeAs(tensor)
    local i = 0
    res:apply(
        function () 
            i = i + 1
            return tensor[indices[i]:clone():storage()] 
        end)
    return res
end

这里tensor[indices[i]:clone():storage()]只是从多维张量中选取元素的通用方法。在 k 维情况下,这完全类似于 tensor[{indices[i][1], ... , indices[i][k]}].

如果您不必提取大量值,此方法效果很好(瓶颈是 :apply 方法,该方法无法使用许多优化技术和 SIMD 指令,因为它执行的函数是黑色的盒子)。这项工作可以更有效地完成:方法 :index 完全可以满足您的需求……使用一维张量。多维target/index张量需要展平:

function flatten_indices(sp_indices, shape)
    sp_indices = sp_indices - 1
    local n_elem, n_dim = sp_indices:size(1), sp_indices:size(2)
    local flat_ind = torch.LongTensor(n_elem):fill(1)

    local mult = 1
    for d = n_dim, 1, -1 do
        flat_ind:add(sp_indices[{{}, d}] * mult)
        mult = mult * shape[d]
    end
    return flat_ind
end

function get_elems_efficient(tensor, sp_indices)
    local flat_indices = flatten_indices(sp_indices, tensor:size()) 
    local flat_tensor = tensor:view(-1)
    return flat_tensor:index(1, flat_indices)
end

差别很大:

n = 500000
k = 100
a = torch.rand(n, k)
ind = torch.LongTensor(n, 2)
ind[{{}, 1}]:random(1, n)
ind[{{}, 2}]:random(1, k)

elems1 = get_elems_simple(a, ind)      # 4.53 sec
elems2 = get_elems_efficient(a, ind)   # 0.05 sec

print(torch.all(elems1:eq(elems2)))    # true