torch.cat的逆运算
Reverse operation of torch.cat
假设我有一个像 [A,A,A,A...,A]
.
这样的张量
如何快速获取[[A],[A],[A],...,[A]]
作为torch中的张量?
从 the doc 开始,使用 split
或 reshape
。
>>> a = torch.arange(10).reshape(5,2)
>>> a
tensor([[0, 1],
[2, 3],
[4, 5],
[6, 7],
[8, 9]])
>>> torch.split(a, 2)
(tensor([[0, 1],
[2, 3]]),
tensor([[4, 5],
[6, 7]]),
tensor([[8, 9]]))
>>> torch.split(a, [1,4])
(tensor([[0, 1]]),
tensor([[2, 3],
[4, 5],
[6, 7],
[8, 9]]))
您可以使用 torch.chunk
作为 cat
的反函数,但看起来您想要 unsqueeze(1)
:
A = torch.randn(2, 3)
A_rep = (A, A, A, A, A, A, A, A)
catted = torch.cat(A_rep)
#uncatted = torch.chunk(catted, len(A_rep))
catted.unsqueeze(1)
您正在寻找 torch.unbind
完全符合您的要求。
import torch
tensor = torch.rand(3, 4, 5) # tensor of shape (3, 4, 5)
l = tensor.unbind(dim=0) # list of 3 tensors of shape (4, 5)
你可以沿着你想要的任何维度解除绑定,默认是暗淡的0。你可以使用python解构来快速获取变量中的张量:
a, b, c = tensor.unbind(dim=0) # 3 tensors of shape (4, 5)
假设我有一个像 [A,A,A,A...,A]
.
如何快速获取[[A],[A],[A],...,[A]]
作为torch中的张量?
从 the doc 开始,使用 split
或 reshape
。
>>> a = torch.arange(10).reshape(5,2)
>>> a
tensor([[0, 1],
[2, 3],
[4, 5],
[6, 7],
[8, 9]])
>>> torch.split(a, 2)
(tensor([[0, 1],
[2, 3]]),
tensor([[4, 5],
[6, 7]]),
tensor([[8, 9]]))
>>> torch.split(a, [1,4])
(tensor([[0, 1]]),
tensor([[2, 3],
[4, 5],
[6, 7],
[8, 9]]))
您可以使用 torch.chunk
作为 cat
的反函数,但看起来您想要 unsqueeze(1)
:
A = torch.randn(2, 3)
A_rep = (A, A, A, A, A, A, A, A)
catted = torch.cat(A_rep)
#uncatted = torch.chunk(catted, len(A_rep))
catted.unsqueeze(1)
您正在寻找 torch.unbind
完全符合您的要求。
import torch
tensor = torch.rand(3, 4, 5) # tensor of shape (3, 4, 5)
l = tensor.unbind(dim=0) # list of 3 tensors of shape (4, 5)
你可以沿着你想要的任何维度解除绑定,默认是暗淡的0。你可以使用python解构来快速获取变量中的张量:
a, b, c = tensor.unbind(dim=0) # 3 tensors of shape (4, 5)