Tensor[batch_mask, ...] 是做什么的?
What does Tensor[batch_mask, ...] do?
我在 BiLSTM 的实现中看到了这行代码:
batch_output = batch_output[batch_mask, ...]
我认为这是某种“屏蔽”操作,但在 Google 上找不到关于 ...
含义的信息。请帮忙:).
原码:
class BiLSTM(nn.Module):
def __init__(self, vocab_size, tagset, embedding_dim, hidden_dim,
num_layers, bidirectional, dropout, pretrained=None):
# irrelevant code ..........
def forward(self, batch_input, batch_input_lens, batch_mask):
batch_size, padding_length = batch_input.size()
batch_input = self.word_embeds(batch_input) # size: #batch * padding_length * embedding_dim
batch_input = rnn_utils.pack_padded_sequence(
batch_input, batch_input_lens, batch_first=True)
batch_output, self.hidden = self.lstm(batch_input, self.hidden)
self.repackage_hidden(self.hidden)
batch_output, _ = rnn_utils.pad_packed_sequence(batch_output, batch_first=True)
batch_output = batch_output.contiguous().view(batch_size * padding_length, -1)
####### HERE ##########
batch_output = batch_output[batch_mask, ...]
#########################
out = self.hidden2tag(batch_output)
return out
此语句使用 batch_mask
包含的索引屏蔽了 batch_output
的第一个维度。实际上,这意味着您要从批处理中选择一些元素。
这是一个实际的例子:
>>> x = torch.rand(3,1,4,4)
tensor([[[[0.5216, 0.1122, 0.0396, 0.5824],
[0.7685, 0.5583, 0.2817, 0.9678],
[0.8878, 0.9477, 0.2554, 0.8261],
[0.2708, 0.3403, 0.7734, 0.2584]]],
[[[0.5471, 0.5031, 0.3906, 0.7554],
[0.1895, 0.3985, 0.7083, 0.7849],
[0.3128, 0.6733, 0.9223, 0.5345],
[0.2689, 0.9876, 0.1092, 0.7405]]],
[[[0.9834, 0.0276, 0.7114, 0.2872],
[0.3483, 0.2104, 0.1816, 0.5615],
[0.4323, 0.5329, 0.9198, 0.8647],
[0.9054, 0.5763, 0.7939, 0.8388]]]])
有掩码和掩码操作:
>>> mask = [0, 2]
>>> x[mask]
tensor([[[[0.5216, 0.1122, 0.0396, 0.5824],
[0.7685, 0.5583, 0.2817, 0.9678],
[0.8878, 0.9477, 0.2554, 0.8261],
[0.2708, 0.3403, 0.7734, 0.2584]]],
[[[0.9834, 0.0276, 0.7114, 0.2872],
[0.3483, 0.2104, 0.1816, 0.5615],
[0.4323, 0.5329, 0.9198, 0.8647],
[0.9054, 0.5763, 0.7939, 0.8388]]]])
仅保留索引 0
和 2
处的元素。
注意:x[mask]
与 x[mask, ...]
相同,其中省略号不是必需的,因为默认情况下所有定位的维度都将选择其所有索引。
我假设 batch_mask
是一个布尔张量。在这种情况下,batch_output[batch_mask]
执行布尔索引,选择对应于 batch_mask
.
中的 True
的元素
...
通常被称为 ellipsis,对于 PyTorch(还有其他 NumPy-like 库),它是 shorthand 以避免多次重复列运算符 (:
)。例如,给定 tensor
v
,其中 v.shape
等于 (2, 3, 4)
,表达式 v[1, :, :]
可以重写为 v[1, ...]
.
我进行了一些测试,使用 batch_output[batch_mask, ...]
或 batch_output[batch_mask]
似乎效果相同:
t = torch.arange(24).reshape(2, 3, 4)
# mask.shape == (2, 3)
mask = torch.tensor([[False, True, True], [True, False, False]])
print(torch.all(t[mask] == t[mask, ...])) # returns True
我在 BiLSTM 的实现中看到了这行代码:
batch_output = batch_output[batch_mask, ...]
我认为这是某种“屏蔽”操作,但在 Google 上找不到关于 ...
含义的信息。请帮忙:).
原码:
class BiLSTM(nn.Module):
def __init__(self, vocab_size, tagset, embedding_dim, hidden_dim,
num_layers, bidirectional, dropout, pretrained=None):
# irrelevant code ..........
def forward(self, batch_input, batch_input_lens, batch_mask):
batch_size, padding_length = batch_input.size()
batch_input = self.word_embeds(batch_input) # size: #batch * padding_length * embedding_dim
batch_input = rnn_utils.pack_padded_sequence(
batch_input, batch_input_lens, batch_first=True)
batch_output, self.hidden = self.lstm(batch_input, self.hidden)
self.repackage_hidden(self.hidden)
batch_output, _ = rnn_utils.pad_packed_sequence(batch_output, batch_first=True)
batch_output = batch_output.contiguous().view(batch_size * padding_length, -1)
####### HERE ##########
batch_output = batch_output[batch_mask, ...]
#########################
out = self.hidden2tag(batch_output)
return out
此语句使用 batch_mask
包含的索引屏蔽了 batch_output
的第一个维度。实际上,这意味着您要从批处理中选择一些元素。
这是一个实际的例子:
>>> x = torch.rand(3,1,4,4)
tensor([[[[0.5216, 0.1122, 0.0396, 0.5824],
[0.7685, 0.5583, 0.2817, 0.9678],
[0.8878, 0.9477, 0.2554, 0.8261],
[0.2708, 0.3403, 0.7734, 0.2584]]],
[[[0.5471, 0.5031, 0.3906, 0.7554],
[0.1895, 0.3985, 0.7083, 0.7849],
[0.3128, 0.6733, 0.9223, 0.5345],
[0.2689, 0.9876, 0.1092, 0.7405]]],
[[[0.9834, 0.0276, 0.7114, 0.2872],
[0.3483, 0.2104, 0.1816, 0.5615],
[0.4323, 0.5329, 0.9198, 0.8647],
[0.9054, 0.5763, 0.7939, 0.8388]]]])
有掩码和掩码操作:
>>> mask = [0, 2]
>>> x[mask]
tensor([[[[0.5216, 0.1122, 0.0396, 0.5824],
[0.7685, 0.5583, 0.2817, 0.9678],
[0.8878, 0.9477, 0.2554, 0.8261],
[0.2708, 0.3403, 0.7734, 0.2584]]],
[[[0.9834, 0.0276, 0.7114, 0.2872],
[0.3483, 0.2104, 0.1816, 0.5615],
[0.4323, 0.5329, 0.9198, 0.8647],
[0.9054, 0.5763, 0.7939, 0.8388]]]])
仅保留索引 0
和 2
处的元素。
注意:x[mask]
与 x[mask, ...]
相同,其中省略号不是必需的,因为默认情况下所有定位的维度都将选择其所有索引。
我假设 batch_mask
是一个布尔张量。在这种情况下,batch_output[batch_mask]
执行布尔索引,选择对应于 batch_mask
.
True
的元素
...
通常被称为 ellipsis,对于 PyTorch(还有其他 NumPy-like 库),它是 shorthand 以避免多次重复列运算符 (:
)。例如,给定 tensor
v
,其中 v.shape
等于 (2, 3, 4)
,表达式 v[1, :, :]
可以重写为 v[1, ...]
.
我进行了一些测试,使用 batch_output[batch_mask, ...]
或 batch_output[batch_mask]
似乎效果相同:
t = torch.arange(24).reshape(2, 3, 4)
# mask.shape == (2, 3)
mask = torch.tensor([[False, True, True], [True, False, False]])
print(torch.all(t[mask] == t[mask, ...])) # returns True