在 PyTorch 中沿给定轴改组

Shuffling along a given axis in PyTorch

我有一个数据集加载了以下维度 [batch_size, seq_len, n_features](例如 torch.Size([16, 600, 130]))。

我希望能够沿序列长度 axis=1 打乱这些数据,而不改变 PyTorch 中的批量排序或特征向量排序。

进一步说明:为了举例说明,假设我的批量大小为 3,序列长度为 3,特征数为 2。

示例: tensor([[[1,1],[2,2],[3,3]],[[4,4],[5,5],[6,6]],[[7,7],[8,8],[9,9]]]) 我希望能够按照以下方式随机洗牌:

tensor([[[3,3],[1,1],[2,2]],[[6,6],[5,5],[4,4]],[[8,8],[7,7],[9,9]]])

是否有任何 PyTorch 函数可以自动为我执行此操作,或者有人知道什么是实现此操作的好方法吗?

您可以使用 torch.randperm.

对于张量t,你可以使用:

t[:,torch.randperm(t.shape[1]),:]

以你的例子为例:

>>> t = torch.tensor([[[1,1],[2,2],[3,3]],[[4,4],[5,5],[6,6]],[[7,7],[8,8],[9,9]]])
>>> t
tensor([[[1, 1],
         [2, 2],
         [3, 3]],

        [[4, 4],
         [5, 5],
         [6, 6]],

        [[7, 7],
         [8, 8],
         [9, 9]]])
>>> t[:,torch.randperm(t.shape[1]),:]
tensor([[[2, 2],
         [3, 3],
         [1, 1]],

        [[5, 5],
         [6, 6],
         [4, 4]],

        [[8, 8],
         [9, 9],
         [7, 7]]])

沿轴随机播放的两部分答案。

  • 首先,直接解决方案为轴 1 的每一“行”提供不同的随机排列。
  • 其次,用于随机排列任何轴的通用随机排列“行”函数。

旁注 1: 抱歉,我的回答晚了几个月 - 我自己也有这个问题,但我无法在网上找到解决问题的简单方法,所以在这里是。

旁注 2: 如前所述,@GoodDeeds 的回答很好,给出了跨其他轴的 相同 随机排列。这给出了跨其他轴的不同排列。

首先,axis=1的直观示例:

输入:

>>> a
tensor([[[1, 1],
         [2, 2],
         [3, 3]],

        [[4, 4],
         [5, 5],
         [6, 6]],

        [[7, 7],
         [8, 8],
         [9, 9]]])

Select 轴 1 的随机“行”。

>>> z = torch.rand(a.shape[:2]).argsort(1)  # define random "row" indices
>>> z = z.unsqueeze(-1).repeat(1, 1, *(a.shape[2:]))  # reformat this for the gather operation.  Note that this works only for dim=1.
>>> output = a.gather(1, z)

输出:

>>> output
tensor([[[2, 2],
         [3, 3],
         [1, 1]],

        [[5, 5],
         [6, 6],
         [4, 4]],

        [[8, 8],
         [9, 9],
         [7, 7]]])

其次,对任何轴的泛化:

如果 PyTorch 的标准库中有这个函数就好了。我会提出一个问题 link 到这个 post.

def shufflerow(tensor, axis):
    row_perm = torch.rand(tensor.shape[:axis+1]).argsort(axis)  # get permutation indices
    for _ in range(tensor.ndim-axis-1): row_perm.unsqueeze_(-1)
    row_perm = row_perm.repeat(*[1 for _ in range(axis+1)], *(tensor.shape[axis+1:]))  # reformat this for the gather operation
    return tensor.gather(axis, row_perm)

示例:

>>> x = torch.arange(2*3*4).reshape(2,3,4)
>>> x
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])

随机轴 0:

>>> shufflerow(x, 0)
tensor([[[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]],

        [[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]]])

随机轴 1

>>> shufflerow(x, 1)
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[16, 17, 18, 19],
         [12, 13, 14, 15],
         [20, 21, 22, 23]]])

随机轴 2

>>> shufflerow(x, 2)
tensor([[[ 2,  0,  1,  3],
         [ 5,  6,  7,  4],
         [11, 10,  9,  8]],

        [[15, 14, 13, 12],
         [18, 17, 19, 16],
         [23, 20, 22, 21]]])