向量化 for 循环 - 需要平均不同大小的切片

Vectorize for-loop - need to average slices of varying size

我正在尝试对子词嵌入进行平均以形成词级表示。每个词都有对应的起始索引和结束索引,表示该词由哪些子词组成。

sequence_output是B * 3 * 2的张量,其中3是最大序列长度,2是特征个数。

all_token_mapping是一个B * 3 * 2的张量,包含起始索引和结束索引。

initial_reps是num_nodes*2的张量,num_nodes是所有字数(不是子字)的总和在不同的样本中。

sequence_output = torch.arange(2*3*2).float().reshape(2, 3, 2)
tensor([[[ 0.,  1.],
         [ 2.,  3.],
         [ 4.,  5.]],

        [[ 6.,  7.],
         [ 8.,  9.],
         [10., 11.]]])
all_token_mapping = torch.tensor([[[0,0],[1,2],[-1,-1]], [[0,2],[-1,-1],[-1,-1]]])
tensor([[[ 0,  0],
         [ 1,  2],
         [-1, -1]],

        [[ 0,  2],
         [-1, -1],
         [-1, -1]]])
num_nodes = 0
for sample in all_token_mapping:
  for mapping in sample:
    if mapping[0] != -1:
      num_nodes += 1
3
initial_reps = torch.empty((num_nodes, 2), dtype=torch.float32)
current_idx = 0
for i, feature_tokens_mapping in enumerate(all_token_mapping):
    for j, token_mapping in enumerate(feature_tokens_mapping):
        if token_mapping[0] == -1: # reached the end for this particular sequence
            break
        initial_reps[current_idx] = torch.mean(sequence_output[i][token_mapping[0]:token_mapping[-1] + 1], 0, keepdim=True)                                           
        current_idx += 1
initial_reps
tensor([[0., 1.],
        [3., 4.],
        [8., 9.]])

在上面的示例中,initial_reps[0] 将是 sequence_output[0][0:1] 的平均值,initial_reps[1] 将是 sequence_output[0][0:1] 的平均值sequence_output[0][1:3],initial_reps[2] 将是 sequence_output[1][0:3].

的平均值

我当前的代码将创建一个长度为 num_nodes 的空张量,for 循环将通过检查 token_mapping[0] 和 token_mapping[ 来计算每个索引处的值1] 对 sequence_output 的正确切片进行平均。

有没有办法向量化此代码?

此外,我有一个列表,其中包含每个样本的单词数。即列表中所有元素的总和 == num_nodes

我会像我在下面的代码中展示的那样做 smth。这个案例非常简单,所以我可以给你看一个输入和输出的例子。但这个概念可以扩展到任何数组大小或维度。我将“768”更改为我设置为 5 的变量 'num_features'。并将源节点的数量从 384 减少到 4。

导入手电筒

B = 3
num_nodes0 = 4
num_nodes = 3
num_features = 5

sequence_output = torch.arange(B * num_nodes0 * num_features).float()
sequence_output = sequence_output.reshape(B, num_nodes0, num_features)

all_token_mapping = torch.randint(0, num_nodes0, (B, num_nodes))

idx0 = torch.arange(B).reshape(-1, 1).repeat(1, num_nodes).flatten().long()
idx1 = all_token_mapping.flatten().long()
initial_reps = sequence_output[idx0, idx1, :].reshape(B, num_nodes, num_features)
initial_reps = torch.mean(initial_reps, axis = 1)

输出:

sequence_output = 
tensor([[[ 0.,  1.,  2.,  3.,  4.],
         [ 5.,  6.,  7.,  8.,  9.],
         [10., 11., 12., 13., 14.],
         [15., 16., 17., 18., 19.]],

        [[20., 21., 22., 23., 24.],
         [25., 26., 27., 28., 29.],
         [30., 31., 32., 33., 34.],
         [35., 36., 37., 38., 39.]],

        [[40., 41., 42., 43., 44.],
         [45., 46., 47., 48., 49.],
         [50., 51., 52., 53., 54.],
         [55., 56., 57., 58., 59.]]])
all_token_mapping = 
tensor([[0, 2, 1],
        [2, 3, 0],
        [2, 0, 0]])
initial_reps = 
tensor([[ 5.0000,  6.0000,  7.0000,  8.0000,  9.0000],
        [28.3333, 29.3333, 30.3333, 31.3333, 32.3333],
        [43.3333, 44.3333, 45.3333, 46.3333, 47.3333]])

https://discuss.pytorch.org/t/vectorize-for-loop-need-to-average-slices-of-varying-size/122618/2

的某人的帮助下,我找到了一个方法
initial_reps_list = []
for i, sample_output in enumerate(sequence_output):
    token_mapping = all_token_mapping[i]
    token_mapping = token_mapping[token_mapping != -1]
    non_padded_outputs = sample_output[:num_bert_tokens[i]]
    initial_reps_list.append(torch_scatter.segment_coo(non_padded_outputs, token_mapping, reduce="mean"))

initial_reps = torch.cat(initial_reps_list)

token_mapping 是按升序排列的索引列表,直到最大序列长度,并用 -1 填充。我循环遍历批次,对于每个样本,我得到标记映射,并且只保留非负索引。

num_bert_tokens 是一个列表,其中包含每个样本的标记数(无填充)。我得到了非填充输出,使用 segment_coo 根据 token_mapping 减少它们,并将它们全部附加到列表中。

循环后,我将列表中的所有张量连接在一起。

方法 segment_coo 将 src 张量中的所有值沿索引的最后一个维度减少到索引张量中指定的索引处。可以在以下位置找到更多详细信息:https://pytorch-scatter.readthedocs.io/en/latest/functions/segment_coo.html

现在运行速度快多了!

我遇到了类似的问题,发现这里可以使用累计和:

>>> import torch
>>> hidden_states = torch.arange(4 * 5 * 2).reshape(4, 5, 2)
>>> target_segments_starts = torch.tensor([[0, 3, 0], [1, 0, 0], [1, 3, 0], [0, 1, 3]])
>>> target_segments_ends = torch.tensor([[2, 5, 0], [4, 0, 0], [3, 5, 0], [1, 3, 5]])

请注意,我使用另一种方式来定义开始和结束索引。此外,我使用 0 进行填充,因为我无法将 -1 传递给 torch.gather

>>> starts_reshaped = target_segments_starts.unsqueeze(-1).repeat(1, 1, hidden_states.shape[-1])  # just for torch.gather
>>> ends_reshaped = target_segments_ends.unsqueeze(-1).repeat(1, 1, hidden_states.shape[-1])  # just for torch.gather
>>> ends_reshaped_included = torch.where(ends_reshaped == 0, 0, ends_reshaped - 1)
>>> starts_reshaped_excluded = torch.where(starts_reshaped == 0, 0, starts_reshaped - 1)
>>> hidden_states_cumsum = hidden_states.cumsum(dim=1)
>>> starts_sums_excluded = torch.gather(hidden_states_cumsum, dim=1, index=starts_reshaped_excluded)
>>> ends_sums_included = torch.gather(hidden_states_cumsum, dim=1, index=ends_reshaped_included)

对于起始 ID = 0 的第一个段,我们需要将 cumsum 设置为 0,因为我们需要排除 ID 的 cumsum,因此 0 的排除 ID 为 -1,cumsum 为 0:

>>> starts_sums_excluded[:, 0][starts_reshaped[:, 0] == 0] = 0.0
>>> hidden_states_sum = ends_sums_included - starts_sums_excluded
>>> segment_lengths = ends_reshaped - starts_reshaped

-1 = 任意数,因为 hidden_states_sum 在需要的地方已经是 0,因此进一步除法将得到 0:

>>> segment_lengths[segment_lengths == 0] = -1

结果:

>>> hidden_states
tensor([[[ 0,  1],
         [ 2,  3],
         [ 4,  5],
         [ 6,  7],
         [ 8,  9]],

        [[10, 11],
         [12, 13],
         [14, 15],
         [16, 17],
         [18, 19]],

        [[20, 21],
         [22, 23],
         [24, 25],
         [26, 27],
         [28, 29]],

        [[30, 31],
         [32, 33],
         [34, 35],
         [36, 37],
         [38, 39]]])
>>> hidden_states_sum
tensor([[[ 2,  4],
         [14, 16],
         [ 0,  0]],

        [[42, 45],
         [ 0,  0],
         [ 0,  0]],

        [[46, 48],
         [54, 56],
         [ 0,  0]],

        [[30, 31],
         [66, 68],
         [74, 76]]])
>>> hidden_states_sum / segment_lengths
tensor([[[ 1.,  2.],
         [ 7.,  8.],
         [-0., -0.]],

        [[14., 15.],
         [-0., -0.],
         [-0., -0.]],

        [[23., 24.],
         [27., 28.],
         [-0., -0.]],

        [[30., 31.],
         [33., 34.],
         [37., 38.]]])

最后,您可以看到对于填充元素,我得到了零嵌入。您可以根据需要排除它们,并根据需要重塑结果张量。