在 PyTorch 中沿给定轴改组
Shuffling along a given axis in PyTorch
我有一个数据集加载了以下维度 [batch_size, seq_len, n_features]
(例如 torch.Size([16, 600, 130]))。
我希望能够沿序列长度 axis=1
打乱这些数据,而不改变 PyTorch 中的批量排序或特征向量排序。
进一步说明:为了举例说明,假设我的批量大小为 3,序列长度为 3,特征数为 2。
示例:
tensor([[[1,1],[2,2],[3,3]],[[4,4],[5,5],[6,6]],[[7,7],[8,8],[9,9]]])
我希望能够按照以下方式随机洗牌:
tensor([[[3,3],[1,1],[2,2]],[[6,6],[5,5],[4,4]],[[8,8],[7,7],[9,9]]])
是否有任何 PyTorch 函数可以自动为我执行此操作,或者有人知道什么是实现此操作的好方法吗?
您可以使用 torch.randperm
.
对于张量t
,你可以使用:
t[:,torch.randperm(t.shape[1]),:]
以你的例子为例:
>>> t = torch.tensor([[[1,1],[2,2],[3,3]],[[4,4],[5,5],[6,6]],[[7,7],[8,8],[9,9]]])
>>> t
tensor([[[1, 1],
[2, 2],
[3, 3]],
[[4, 4],
[5, 5],
[6, 6]],
[[7, 7],
[8, 8],
[9, 9]]])
>>> t[:,torch.randperm(t.shape[1]),:]
tensor([[[2, 2],
[3, 3],
[1, 1]],
[[5, 5],
[6, 6],
[4, 4]],
[[8, 8],
[9, 9],
[7, 7]]])
沿轴随机播放的两部分答案。
- 首先,直接解决方案为轴 1 的每一“行”提供不同的随机排列。
- 其次,用于随机排列任何轴的通用随机排列“行”函数。
旁注 1: 抱歉,我的回答晚了几个月 - 我自己也有这个问题,但我无法在网上找到解决问题的简单方法,所以在这里是。
旁注 2: 如前所述,@GoodDeeds 的回答很好,给出了跨其他轴的 相同 随机排列。这给出了跨其他轴的不同排列。
首先,axis=1的直观示例:
输入:
>>> a
tensor([[[1, 1],
[2, 2],
[3, 3]],
[[4, 4],
[5, 5],
[6, 6]],
[[7, 7],
[8, 8],
[9, 9]]])
Select 轴 1 的随机“行”。
>>> z = torch.rand(a.shape[:2]).argsort(1) # define random "row" indices
>>> z = z.unsqueeze(-1).repeat(1, 1, *(a.shape[2:])) # reformat this for the gather operation. Note that this works only for dim=1.
>>> output = a.gather(1, z)
输出:
>>> output
tensor([[[2, 2],
[3, 3],
[1, 1]],
[[5, 5],
[6, 6],
[4, 4]],
[[8, 8],
[9, 9],
[7, 7]]])
其次,对任何轴的泛化:
如果 PyTorch 的标准库中有这个函数就好了。我会提出一个问题 link 到这个 post.
def shufflerow(tensor, axis):
row_perm = torch.rand(tensor.shape[:axis+1]).argsort(axis) # get permutation indices
for _ in range(tensor.ndim-axis-1): row_perm.unsqueeze_(-1)
row_perm = row_perm.repeat(*[1 for _ in range(axis+1)], *(tensor.shape[axis+1:])) # reformat this for the gather operation
return tensor.gather(axis, row_perm)
示例:
>>> x = torch.arange(2*3*4).reshape(2,3,4)
>>> x
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
随机轴 0:
>>> shufflerow(x, 0)
tensor([[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]],
[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]]])
随机轴 1
>>> shufflerow(x, 1)
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[16, 17, 18, 19],
[12, 13, 14, 15],
[20, 21, 22, 23]]])
随机轴 2
>>> shufflerow(x, 2)
tensor([[[ 2, 0, 1, 3],
[ 5, 6, 7, 4],
[11, 10, 9, 8]],
[[15, 14, 13, 12],
[18, 17, 19, 16],
[23, 20, 22, 21]]])
我有一个数据集加载了以下维度 [batch_size, seq_len, n_features]
(例如 torch.Size([16, 600, 130]))。
我希望能够沿序列长度 axis=1
打乱这些数据,而不改变 PyTorch 中的批量排序或特征向量排序。
进一步说明:为了举例说明,假设我的批量大小为 3,序列长度为 3,特征数为 2。
示例:
tensor([[[1,1],[2,2],[3,3]],[[4,4],[5,5],[6,6]],[[7,7],[8,8],[9,9]]])
我希望能够按照以下方式随机洗牌:
tensor([[[3,3],[1,1],[2,2]],[[6,6],[5,5],[4,4]],[[8,8],[7,7],[9,9]]])
是否有任何 PyTorch 函数可以自动为我执行此操作,或者有人知道什么是实现此操作的好方法吗?
您可以使用 torch.randperm
.
对于张量t
,你可以使用:
t[:,torch.randperm(t.shape[1]),:]
以你的例子为例:
>>> t = torch.tensor([[[1,1],[2,2],[3,3]],[[4,4],[5,5],[6,6]],[[7,7],[8,8],[9,9]]])
>>> t
tensor([[[1, 1],
[2, 2],
[3, 3]],
[[4, 4],
[5, 5],
[6, 6]],
[[7, 7],
[8, 8],
[9, 9]]])
>>> t[:,torch.randperm(t.shape[1]),:]
tensor([[[2, 2],
[3, 3],
[1, 1]],
[[5, 5],
[6, 6],
[4, 4]],
[[8, 8],
[9, 9],
[7, 7]]])
沿轴随机播放的两部分答案。
- 首先,直接解决方案为轴 1 的每一“行”提供不同的随机排列。
- 其次,用于随机排列任何轴的通用随机排列“行”函数。
旁注 1: 抱歉,我的回答晚了几个月 - 我自己也有这个问题,但我无法在网上找到解决问题的简单方法,所以在这里是。
旁注 2: 如前所述,@GoodDeeds 的回答很好,给出了跨其他轴的 相同 随机排列。这给出了跨其他轴的不同排列。
首先,axis=1的直观示例:
输入:
>>> a
tensor([[[1, 1],
[2, 2],
[3, 3]],
[[4, 4],
[5, 5],
[6, 6]],
[[7, 7],
[8, 8],
[9, 9]]])
Select 轴 1 的随机“行”。
>>> z = torch.rand(a.shape[:2]).argsort(1) # define random "row" indices
>>> z = z.unsqueeze(-1).repeat(1, 1, *(a.shape[2:])) # reformat this for the gather operation. Note that this works only for dim=1.
>>> output = a.gather(1, z)
输出:
>>> output
tensor([[[2, 2],
[3, 3],
[1, 1]],
[[5, 5],
[6, 6],
[4, 4]],
[[8, 8],
[9, 9],
[7, 7]]])
其次,对任何轴的泛化:
如果 PyTorch 的标准库中有这个函数就好了。我会提出一个问题 link 到这个 post.
def shufflerow(tensor, axis):
row_perm = torch.rand(tensor.shape[:axis+1]).argsort(axis) # get permutation indices
for _ in range(tensor.ndim-axis-1): row_perm.unsqueeze_(-1)
row_perm = row_perm.repeat(*[1 for _ in range(axis+1)], *(tensor.shape[axis+1:])) # reformat this for the gather operation
return tensor.gather(axis, row_perm)
示例:
>>> x = torch.arange(2*3*4).reshape(2,3,4)
>>> x
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
随机轴 0:
>>> shufflerow(x, 0)
tensor([[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]],
[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]]])
随机轴 1
>>> shufflerow(x, 1)
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[16, 17, 18, 19],
[12, 13, 14, 15],
[20, 21, 22, 23]]])
随机轴 2
>>> shufflerow(x, 2)
tensor([[[ 2, 0, 1, 3],
[ 5, 6, 7, 4],
[11, 10, 9, 8]],
[[15, 14, 13, 12],
[18, 17, 19, 16],
[23, 20, 22, 21]]])