有没有pytorch函数可以把张量的特定连续维度合二为一?
Is there any pytorch function can combine the specific continuous dimensions of tensor into one?
我们把我要找的函数叫做“magic_combine
”,它可以结合我给它的张量的连续维度。更具体地说,我希望它做以下事情:
a = torch.zeros(1, 2, 3, 4, 5, 6)
b = a.magic_combine(2, 5) # combine dimension 2, 3, 4
print(b.size()) # should be (1, 2, 60, 6)
我知道 torch.view()
可以做类似的事情。但我只是想知道是否有更优雅的方式来实现目标?
我不确定你对 "a more elegant way" 有什么想法,但是 Tensor.view()
的优点是不会为视图重新分配数据(原始张量和视图共享相同的数据),使这个操作非常轻量级。
正如@UmangGupta 所提到的,包装这个函数来实现你想要的是相当直接的,例如:
import torch
def magic_combine(x, dim_begin, dim_end):
combined_shape = list(x.shape[:dim_begin]) + [-1] + list(x.shape[dim_end:])
return x.view(combined_shape)
a = torch.zeros(1, 2, 3, 4, 5, 6)
b = magic_combine(a, 2, 5) # combine dimension 2, 3, 4
print(b.size())
# torch.Size([1, 2, 60, 6])
a = torch.zeros(1, 2, 3, 4, 5, 6)
b = a.view(*a.shape[:2], -1, *a.shape[5:])
在我看来比 简单一点,并且不通过 list
构造函数(3 次)。
也可以使用 torch einops。
> pip install einops
from einops import rearrange
a = torch.zeros(1, 2, 3, 4, 5, 6)
b = rearrange(a, 'd0 d1 d2 d3 d4 d5 -> d0 d1 (d2 d3 d4) d5')
有一个 flatten
的变体,它接受 start_dim
和 end_dim
参数。您可以像 magic_combine
一样调用它(除了 end_dim
是包含在内的)。
a = torch.zeros(1, 2, 3, 4, 5, 6)
b = a.flatten(2, 4) # combine dimension 2, 3, 4
print(b.size()) # should be (1, 2, 60, 6)
https://pytorch.org/docs/stable/generated/torch.flatten.html
还有一个相应的 unflatten
,您可以在其中指定要展开的尺寸和要展开的形状。
我们把我要找的函数叫做“magic_combine
”,它可以结合我给它的张量的连续维度。更具体地说,我希望它做以下事情:
a = torch.zeros(1, 2, 3, 4, 5, 6)
b = a.magic_combine(2, 5) # combine dimension 2, 3, 4
print(b.size()) # should be (1, 2, 60, 6)
我知道 torch.view()
可以做类似的事情。但我只是想知道是否有更优雅的方式来实现目标?
我不确定你对 "a more elegant way" 有什么想法,但是 Tensor.view()
的优点是不会为视图重新分配数据(原始张量和视图共享相同的数据),使这个操作非常轻量级。
正如@UmangGupta 所提到的,包装这个函数来实现你想要的是相当直接的,例如:
import torch
def magic_combine(x, dim_begin, dim_end):
combined_shape = list(x.shape[:dim_begin]) + [-1] + list(x.shape[dim_end:])
return x.view(combined_shape)
a = torch.zeros(1, 2, 3, 4, 5, 6)
b = magic_combine(a, 2, 5) # combine dimension 2, 3, 4
print(b.size())
# torch.Size([1, 2, 60, 6])
a = torch.zeros(1, 2, 3, 4, 5, 6)
b = a.view(*a.shape[:2], -1, *a.shape[5:])
在我看来比 list
构造函数(3 次)。
也可以使用 torch einops。
> pip install einops
from einops import rearrange
a = torch.zeros(1, 2, 3, 4, 5, 6)
b = rearrange(a, 'd0 d1 d2 d3 d4 d5 -> d0 d1 (d2 d3 d4) d5')
有一个 flatten
的变体,它接受 start_dim
和 end_dim
参数。您可以像 magic_combine
一样调用它(除了 end_dim
是包含在内的)。
a = torch.zeros(1, 2, 3, 4, 5, 6)
b = a.flatten(2, 4) # combine dimension 2, 3, 4
print(b.size()) # should be (1, 2, 60, 6)
https://pytorch.org/docs/stable/generated/torch.flatten.html
还有一个相应的 unflatten
,您可以在其中指定要展开的尺寸和要展开的形状。