torch.cat的逆运算

Reverse operation of torch.cat

假设我有一个像 [A,A,A,A...,A].

这样的张量

如何快速获取[[A],[A],[A],...,[A]]作为torch中的张量?

the doc 开始,使用 splitreshape

>>> a = torch.arange(10).reshape(5,2)
>>> a
tensor([[0, 1],
        [2, 3],
        [4, 5],
        [6, 7],
        [8, 9]])
>>> torch.split(a, 2)
(tensor([[0, 1],
         [2, 3]]),
 tensor([[4, 5],
         [6, 7]]),
 tensor([[8, 9]]))
>>> torch.split(a, [1,4])
(tensor([[0, 1]]),
 tensor([[2, 3],
         [4, 5],
         [6, 7],
         [8, 9]]))

您可以使用 torch.chunk 作为 cat 的反函数,但看起来您想要 unsqueeze(1):

A = torch.randn(2, 3)
A_rep = (A, A, A, A, A, A, A, A)
catted = torch.cat(A_rep)

#uncatted = torch.chunk(catted, len(A_rep))
catted.unsqueeze(1)

您正在寻找 torch.unbind 完全符合您的要求。

import torch

tensor = torch.rand(3, 4, 5)  # tensor of shape (3, 4, 5)
l = tensor.unbind(dim=0)  # list of 3 tensors of shape (4, 5)

你可以沿着你想要的任何维度解除绑定,默认是暗淡的0。你可以使用python解构来快速获取变量中的张量:

a, b, c = tensor.unbind(dim=0)  # 3 tensors of shape (4, 5)