torch.where() 可以用于等效的广播形式吗?

Can torch.where() used in a equivalent broadcsating form?

我的代码中有以下 for 循环片段。嵌套循环正在减慢我的完整执行速度。

for q in range(batchSize):
    temp=torch.where((composition_matrix == pred[q]).all(dim=1))[0]
    if len(temp)==0:
        output[q]=0
    else:
        output[q]=int(temp[0])

这里,composition_matrix[14000,2]维的pytorch张量,只有正整数作为单元格值。 predoutput 都是 [batchSize,2] 维火炬张量。 由于这个 for 循环大大减慢了我的代码速度,我无法获得与此代码段等效的广播解决方案。

是否存在消除此 for 循环的广播解决方案?

如有任何帮助,我将不胜感激。

一个最小可重现的例子是

import torch
composition_matrix=torch.randint(3, 10, (14000,2))
batchSize=64
pred=torch.randint(3, 10, (batchSize,2))
output=torch.zeros([batchSize])

for q in range(batchSize):
    temp=torch.where((composition_matrix == pred[q]).all(dim=1))[0]
    if len(temp)==0:
        output[q]=0
    else:
        output[q]=int(temp[0])

对于简短且可能过于简化的示例,我们深表歉意。我担心更大的会更难想象。但我希望这适合你的目的。 我的解决方案可能看起来有点复杂,但它是完全矢量化的并且不包含显式循环。 这是我会做的:

import torch

torch.manual_seed(0)
batchSize = 8

pred               = torch.randint(0, 10, (batchSize, 2))
output             = torch.zeros((batchSize, 2))
composition_matrix = torch.randint(0, 10, (14, 2))

# compair all vectors in composition_matrix to all vectors in pred
comparisons = (composition_matrix.unsqueeze(0) == pred.unsqueeze(1))
comparisons = comparisons.all(2)

# form an index array the shape of the comparisons array
comparison_idxs = torch.arange(comparisons.shape[1])
comparison_idxs = comparison_idxs.repeat(batchSize).reshape(*comparisons.shape)

# multipy the comparisons array by the index array 
where_result = (comparison_idxs*comparisons)

# replace invalind zeros with the maximal value in each sample
batch_idxs   = torch.arange(comparisons.shape[0])
batch_idxs   = batch_idxs.repeat(comparisons.shape[1])
batch_idxs   = batch_idxs.reshape(comparisons.shape[1], comparisons.shape[0]).T
maxima       = where_result.max(1).values[batch_idxs]
maxima_vecor = maxima[(1-comparisons.int()).bool()]
where_result[(1-comparisons.int()).bool()] = maxima_vecor

vectorized_output = where_result.min(1)[0]

output = torch.zeros([batchSize])
for q in range(batchSize):
    temp=torch.where((composition_matrix == pred[q]).all(dim=1))[0]
    if len(temp)==0:
        output[q]=0
    else:
        output[q]=int(temp[0])

输出:

composition_matrix = 
tensor([[6, 8],
        [4, 3],
        [6, 9],
        [1, 4],
        [4, 1],
        [9, 9],
        [9, 0],
        [1, 2],
        [3, 0],
        [5, 5],
        [2, 9],
        [1, 8],
        [8, 3],
        [6, 9]])
pred = 
tensor([[4, 9],
        [3, 0],
        [3, 9],
        [7, 3],
        [7, 3],
        [1, 6],
        [6, 9],
        [8, 6]])
output = 
tensor([0., 8., 0., 0., 0., 0., 2., 0.])
vectorized_output = 
tensor([0, 8, 0, 0, 0, 0, 2, 0])

一些计时结果:

torch.manual_seed(0)
batchSize = 8
pred               = torch.randint(0, 10, (batchSize, 2))
composition_matrix = torch.randint(0, 10, (14000, 2))


print('timing the vectorized_solution:')
%timeit -n 1000 vectorized_solution(composition_matrix, pred,)

print('timing the loop_solution:')
%timeit -n 1000 loop_solution(composition_matrix, pred,)

输出:

timing the vectorized_solution:
1000 loops, best of 5: 137 µs per loop
timing the loop_solution:
1000 loops, best of 5: 1.89 ms per loop

为简单起见,您首先需要了解操作的本质。你有两个张量。张量 A 的形状为 (14000, 2),张量 B 的形状为 (64, 2)。你要做的操作是:

For each row B[i] in B, compare that B[i] (of shape (2,) with A (of shape (14000, 2)). If B[i] occurs within A, set output[i] = index of first occurrence.

这实际上可以用两行代码(甚至一行)完成:

comp = (composition_matrix[:, None, :] == pred).all(dim=-1)
output = torch.argmax(comp.float(), axis=0)
  • 第一行创建 compcomposition_matrixpred 的广播比较,一个形状为 (14000, 64).[=18 的布尔张量=]

  • 第二行需要找到“第一个匹配项的索引”。这可以通过 argmax 非常简单地完成:它将 return 第一个“1”的索引(或者如果所有值都是“0”,将 return 第一个索引,即 0)。

(请注意,torch 不支持“bool”张量的 argmax,因此需要将 comp 转换为另一种数据类型。)