Can torch.where() be used in an equivalent broadcasting form?
I have the following for-loop snippet in my code. This loop is slowing down my complete execution.
for q in range(batchSize):
    temp = torch.where((composition_matrix == pred[q]).all(dim=1))[0]
    if len(temp) == 0:
        output[q] = 0
    else:
        output[q] = int(temp[0])
Here, composition_matrix is a [14000, 2]-dimensional PyTorch tensor with only positive integers as cell values. pred and output are both [batchSize, 2]-dimensional torch tensors.
This for loop slows my code down considerably, and I have not been able to come up with an equivalent broadcasting solution for this snippet.
Does a broadcasting solution exist that eliminates this for loop?
Any help would be appreciated.
A minimal reproducible example is:
import torch

composition_matrix = torch.randint(3, 10, (14000, 2))

batchSize = 64
pred = torch.randint(3, 10, (batchSize, 2))
output = torch.zeros([batchSize])

for q in range(batchSize):
    temp = torch.where((composition_matrix == pred[q]).all(dim=1))[0]
    if len(temp) == 0:
        output[q] = 0
    else:
        output[q] = int(temp[0])
Apologies for the short and possibly over-simplified example; I'm afraid a bigger one would be much harder to visualize. But I hope this suits your purpose.
My solution might look a bit complicated, but it is fully vectorized and contains no explicit loops.
Here is what I would do:
import torch

torch.manual_seed(0)

batchSize = 8
pred = torch.randint(0, 10, (batchSize, 2))
output = torch.zeros((batchSize, 2))
composition_matrix = torch.randint(0, 10, (14, 2))

# compare all vectors in composition_matrix to all vectors in pred
comparisons = (composition_matrix.unsqueeze(0) == pred.unsqueeze(1))
comparisons = comparisons.all(2)

# form an index array the shape of the comparisons array
comparison_idxs = torch.arange(comparisons.shape[1])
comparison_idxs = comparison_idxs.repeat(batchSize).reshape(*comparisons.shape)

# multiply the comparisons array by the index array
where_result = (comparison_idxs * comparisons)

# replace the invalid zeros with the maximal value in each sample
batch_idxs = torch.arange(comparisons.shape[0])
batch_idxs = batch_idxs.repeat(comparisons.shape[1])
batch_idxs = batch_idxs.reshape(comparisons.shape[1], comparisons.shape[0]).T
maxima = where_result.max(1).values[batch_idxs]
maxima_vector = maxima[(1 - comparisons.int()).bool()]
where_result[(1 - comparisons.int()).bool()] = maxima_vector

# the row-wise minimum is now the index of the first match (or 0 if there is none)
vectorized_output = where_result.min(1)[0]

# reference: the original loop, for comparison
output = torch.zeros([batchSize])
for q in range(batchSize):
    temp = torch.where((composition_matrix == pred[q]).all(dim=1))[0]
    if len(temp) == 0:
        output[q] = 0
    else:
        output[q] = int(temp[0])
Output:
composition_matrix =
tensor([[6, 8],
        [4, 3],
        [6, 9],
        [1, 4],
        [4, 1],
        [9, 9],
        [9, 0],
        [1, 2],
        [3, 0],
        [5, 5],
        [2, 9],
        [1, 8],
        [8, 3],
        [6, 9]])
pred =
tensor([[4, 9],
        [3, 0],
        [3, 9],
        [7, 3],
        [7, 3],
        [1, 6],
        [6, 9],
        [8, 6]])
output =
tensor([0., 8., 0., 0., 0., 0., 2., 0.])
vectorized_output =
tensor([0, 8, 0, 0, 0, 0, 2, 0])
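The timing snippet below calls loop_solution and vectorized_solution, which are not defined above; a minimal sketch of how I assume these wrappers look (each simply packages one of the two approaches above into a function) would be:

import torch

def loop_solution(composition_matrix, pred):
    # original loop-based approach
    batchSize = pred.shape[0]
    output = torch.zeros([batchSize])
    for q in range(batchSize):
        temp = torch.where((composition_matrix == pred[q]).all(dim=1))[0]
        output[q] = int(temp[0]) if len(temp) > 0 else 0
    return output

def vectorized_solution(composition_matrix, pred):
    # fully vectorized approach from above, wrapped into a function
    batchSize = pred.shape[0]
    comparisons = (composition_matrix.unsqueeze(0) == pred.unsqueeze(1)).all(2)
    comparison_idxs = torch.arange(comparisons.shape[1])
    comparison_idxs = comparison_idxs.repeat(batchSize).reshape(*comparisons.shape)
    where_result = comparison_idxs * comparisons
    batch_idxs = torch.arange(comparisons.shape[0])
    batch_idxs = batch_idxs.repeat(comparisons.shape[1])
    batch_idxs = batch_idxs.reshape(comparisons.shape[1], comparisons.shape[0]).T
    maxima = where_result.max(1).values[batch_idxs]
    maxima_vector = maxima[(1 - comparisons.int()).bool()]
    where_result[(1 - comparisons.int()).bool()] = maxima_vector
    return where_result.min(1)[0]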
Some timing results:
torch.manual_seed(0)
batchSize = 8
pred = torch.randint(0, 10, (batchSize, 2))
composition_matrix = torch.randint(0, 10, (14000, 2))
print('timing the vectorized_solution:')
%timeit -n 1000 vectorized_solution(composition_matrix, pred,)
print('timing the loop_solution:')
%timeit -n 1000 loop_solution(composition_matrix, pred,)
Output:
timing the vectorized_solution:
1000 loops, best of 5: 137 µs per loop
timing the loop_solution:
1000 loops, best of 5: 1.89 ms per loop
For simplicity, you first need to understand the essence of the operation. You have two tensors: tensor A of shape (14000, 2) and tensor B of shape (64, 2). The operation you want to perform is:
For each row B[i] in B, compare that B[i] (of shape (2,)) with A (of shape (14000, 2)). If B[i] occurs within A, set output[i] = index of first occurrence.
This can actually be done in two lines of code (or even one):
comp = (composition_matrix[:, None, :] == pred).all(dim=-1)
output = torch.argmax(comp.float(), axis=0)
The first line creates comp, a broadcasted comparison of composition_matrix and pred: a boolean tensor of shape (14000, 64).
The second line needs to find the "index of the first match". This can be done quite simply with argmax: it will return the index of the first "1" (or, if all values are "0", it will return the first index, which is 0).
(Note that torch does not support argmax for "bool" tensors, hence comp needs to be cast to another data type.)
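For reference, here is a self-contained sketch of this approach, using the question's variable names and random data (a row of pred with no match in composition_matrix also yields 0, just like the loop):

import torch

composition_matrix = torch.randint(3, 10, (14000, 2))
batchSize = 64
pred = torch.randint(3, 10, (batchSize, 2))

# broadcasted comparison: (14000, 1, 2) vs (64, 2) -> (14000, 64, 2), reduced over the last dim
comp = (composition_matrix[:, None, :] == pred).all(dim=-1)  # bool, shape (14000, 64)

# argmax over dim 0 returns the first matching row index per column
# (and 0 when a column has no match at all, matching the loop's behaviour)
output = torch.argmax(comp.float(), dim=0)

# sanity check against the original loop
loop_output = torch.zeros([batchSize])
for q in range(batchSize):
    temp = torch.where((composition_matrix == pred[q]).all(dim=1))[0]
    loop_output[q] = int(temp[0]) if len(temp) > 0 else 0
print(torch.equal(output.float(), loop_output))  # expected: True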