Efficient way of selectively replacing vectors from a tensor in pytorch

Given a batch of text sequences converted into a tensor, with each word represented by a word embedding / vector (300-dimensional), I need to selectively replace the vectors of certain specific words with a new set of embeddings. Moreover, the replacement should not happen for every occurrence of a given word, but only at random. Currently I have the following code to do this. It iterates over every word with two for loops, checks whether the word is in the specified list splIndices, and then, based on the True/False value in selected_, decides whether the word needs to be replaced.

But can this be done in a more efficient way?

The code below may not be an MWE, but I have tried to simplify it by removing details so as to focus on the problem. Please ignore the semantics or purpose of the code, as it may not be represented appropriately in this snippet. The question is about improving performance.
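
For context, a minimal setup that the snippet below could run against might look like this; the vocabulary size and the contents of embedding_matrix and x are placeholders, not taken from the original question:

import torch
import torch.nn as nn

vocab_size, emb_dim = 5000, 300                     # assumed vocabulary size, 300-dim embeddings
embedding_matrix = torch.rand(vocab_size, emb_dim)  # placeholder for the pretrained vectors
x = torch.randint(0, vocab_size, (32, 41))          # batch of 32 sequences with 41 word indices each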


splIndices = [45, 62, 2983, 456, 762]  # vocabulary indices that need to be replaced
splFreqs = 2000  # assuming the words in splIndices occur 2000 times in total
selected_ = torch.Tensor(2000).uniform_(0, 1) > 0.2  # random mask, roughly 80% of the entries True
replIndexCtr = 0  # counter for selected_

# Dictionary with vectors to be replaced. This is a dummy function.
# Original function depends on some property of the word
diffVector = {45: torch.Tensor(300).uniform_(0, 1), ..., 762: torch.Tensor(300).uniform_(0, 1)}

embeding = nn.Embedding.from_pretrained(embedding_matrix, freeze=False)
tempVals = x  # shape [32, 41] - batch of 32 sequences with 41 words each
x = embeding(x) # shape [32, 41, 300] - the sequence now has replaced vocab indices with embeddings

# iterate through batch for sequences
for i, item in enumerate(x):
    # iterate sequences for words
    for j, stuff in enumerate(item):
        if tempVals[i][j].item() in splIndices:
            if selected_[replIndexCtr]:
                x[i, j] = diffVector[tempVals[i][j].item()]
            replIndexCtr += 1  # advance once per occurrence of a special word


It can be vectorized in the following way:

import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size, sentence_size, vocab_size, emb_size = 3, 2, 15, 1

# Use a distinctive bias value as a marker for each embedder
embedder_1 = nn.Linear(vocab_size, emb_size)
embedder_1.weight.data.fill_(0)
embedder_1.bias.data.fill_(200)

embedder_2 = nn.Linear(vocab_size, emb_size)
embedder_2.weight.data.fill_(0)
embedder_2.bias.data.fill_(404)

# Here are the indices of words which need a different embedding
replace_list = [3, 5, 7, 9] 

# Make a binary mask highlighting the special words' indices
mask = torch.zeros(batch_size, sentence_size, vocab_size)
mask[..., replace_list] = 1

# Make random dataset
data_indices = torch.randint(0, vocab_size, (batch_size, sentence_size))
data_onehot = F.one_hot(data_indices, vocab_size)

# Check if onehot of a word collides with replace mask 
replace_mask = mask.long() * data_onehot
replace_mask = torch.sum(replace_mask, dim=-1).byte() # byte() is critical here

data_emb = torch.empty(batch_size, sentence_size, emb_size)

# Fill default embeddings
data_emb[1-replace_mask] = embedder_1(data_onehot[1-replace_mask].float())
if torch.max(replace_mask) != 0: # If not all zeros
    # Fill special embeddings
    data_emb[replace_mask] = embedder_2(data_onehot[replace_mask].float())

print(data_indices)
print(replace_mask)
print(data_emb.squeeze(-1).int())

Here is a possible output example:

# Word indices
tensor([[ 6,  9],
        [ 5, 10],
        [ 4, 11]])
# Embedding replacement mask
tensor([[0, 1],
        [1, 0],
        [0, 0]], dtype=torch.uint8)
# Resulting replacement
tensor([[200, 404],
        [404, 200],
        [200, 200]], dtype=torch.int32)
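
Mapping the same masking idea back to the original setting (nn.Embedding, splIndices, and a random per-occurrence decision) could look roughly like the sketch below. The sizes, repl_vectors, and the replacement probability are assumptions mirroring the question, not part of the original answer:

import torch
import torch.nn as nn

vocab_size, emb_dim = 5000, 300                       # assumed sizes
embedding_matrix = torch.rand(vocab_size, emb_dim)    # placeholder pretrained vectors
embedding = nn.Embedding.from_pretrained(embedding_matrix, freeze=False)

splIndices = torch.tensor([45, 62, 2983, 456, 762])
# One replacement vector per special word (plays the role of diffVector)
repl_vectors = torch.rand(len(splIndices), emb_dim)

x = torch.randint(0, vocab_size, (32, 41))            # batch of word indices
emb = embedding(x)                                     # [32, 41, 300]

# Which special word (if any) occurs at each position: [32, 41, 5] boolean
hits = x.unsqueeze(-1) == splIndices
is_spl = hits.any(-1)                                  # [32, 41], True where a special word occurs

# Random per-occurrence decision, playing the role of selected_ (True with prob. 0.8 here)
selected = torch.rand(x.shape) > 0.2
replace = is_spl & selected                            # positions that actually get replaced

# Index into repl_vectors of the matching special word (only meaningful where is_spl is True)
which = hits.long().argmax(-1)

# Vectorized, loop-free replacement (in-place, like x[i, j] = ... in the question)
emb[replace] = repl_vectors[which[replace]]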