Create a Numpy matrix storing shuffled versions of an input ndarray

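I have a 2-d ndarray called weights of shape (npts, nweights). For every column of weights, I wish to randomly shuffle the rows. I want to repeat this process num_shuffles times and store the collection of shufflings in a 3-d ndarray called weights_matrix. Importantly, for each shuffling iteration, the shuffling indices of each column of weights should be the same.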

Below is an explicit, naive double-for-loop implementation of this algorithm. Is it possible to avoid the Python loops and generate weights_matrix in pure Numpy?

import numpy as np 
npts, nweights = 5, 2
weights = np.random.rand(npts*nweights).reshape((npts, nweights))

num_shuffles = 3
weights_matrix = np.zeros((num_shuffles, npts, nweights))
for i in range(num_shuffles):
    indx = np.random.choice(np.arange(npts), npts, replace=False)
    for j in range(nweights):
        weights_matrix[i, :, j] = weights[indx, j]
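Because the same index array indx is used for every column j, the inner loop is equivalent to indexing whole rows at once with weights[indx]. A minimal sketch that keeps only the outer loop, assuming NumPy's default_rng Generator API (NumPy 1.17+) in place of the legacy np.random functions; the seed is arbitrary and only there for reproducibility:

```python
import numpy as np

rng = np.random.default_rng(0)  # arbitrary seed, for reproducibility only
npts, nweights, num_shuffles = 5, 2, 3
weights = rng.random((npts, nweights))

weights_matrix = np.zeros((num_shuffles, npts, nweights))
for i in range(num_shuffles):
    # A random permutation of 0..npts-1, like choice(arange(npts), npts, replace=False)
    indx = rng.permutation(npts)
    # Indexing the first axis moves whole rows, so every column gets the same permutation
    weights_matrix[i] = weights[indx]
```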

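You can first fill a 3-D array with copies of the original weights, then do a simple loop over the slices of that 3-D array, using numpy.random.shuffle to shuffle each 2-D slice in place.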

For every column of weights, I wish to randomly shuffle the rows...the shuffling indices of each column of weights should be the same

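is just another way of saying "I want to randomly reorder the rows of a 2D array". numpy.random.shuffle is a numpy-array-aware version of random.shuffle: it reorders the elements of a container in place. That's all you need, since the "elements" of a 2-D numpy array, in this sense, are its rows.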
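The claim that numpy.random.shuffle treats rows as the elements of a 2-D array can be checked directly. In this small sketch (the array values are chosen only so that each row is easy to recognize), every row survives intact and the values within a row stay paired:

```python
import numpy as np

w = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
original_rows = {tuple(row) for row in w.tolist()}

np.random.shuffle(w)  # shuffles along the first axis only, in place; returns None

# Only the order of the rows can change; each row is still one of the originals.
print(w)
print({tuple(row) for row in w.tolist()} == original_rows)
```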

import numpy
weights = numpy.array( [ [ 1, 2, 3 ], [ 4, 5, 6], [ 7, 8, 9 ] ] )
weights_3d = weights[ numpy.newaxis, :, : ].repeat( 10, axis=0 )

for w in weights_3d:
    numpy.random.shuffle( w )  # in-place shuffle of the rows of each slice

print( weights_3d[0, :, :] )
print( weights_3d[1, :, :] )
print( weights_3d[2, :, :] )
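The same loop can be written with the newer Generator API. This is a sketch assuming NumPy 1.17+ (default_rng), where rng.shuffle plays the role of numpy.random.shuffle and also shuffles along the first axis by default; the seed is arbitrary:

```python
import numpy as np

rng = np.random.default_rng(42)  # arbitrary seed
weights = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
num_shuffles = 10

# np.repeat along a new leading axis produces an independent copy per shuffle,
# so shuffling the slices never touches the original weights.
weights_3d = np.repeat(weights[np.newaxis, :, :], num_shuffles, axis=0)

for w in weights_3d:  # each w is a view of one (npts, nweights) slice
    rng.shuffle(w)    # permutes the rows of that slice in place (axis=0)
```

Iterating over the 3-D array yields views, which is why the in-place shuffle of each w updates weights_3d itself.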

Here's a vectorized solution, with the idea borrowed from -

weights[np.random.rand(num_shuffles,weights.shape[0]).argsort(1)]
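The one-liner works in two steps. argsort of a row of independent uniform random keys is a uniformly random permutation of 0..npts-1 (ties between floats are vanishingly unlikely), and indexing the first axis of a 2-D array with an integer array of shape (num_shuffles, npts) returns whole rows, giving shape (num_shuffles, npts, nweights). A step-by-step sketch, using default_rng instead of np.random.rand (an assumption; the seed is arbitrary):

```python
import numpy as np

rng = np.random.default_rng(1)  # arbitrary seed
npts, nweights, num_shuffles = 5, 2, 3
weights = rng.random((npts, nweights))

# One row of random keys per shuffle; argsort along axis 1 turns each row of
# keys into a random permutation of 0..npts-1.
idx = rng.random((num_shuffles, npts)).argsort(axis=1)  # shape (num_shuffles, npts)

# Integer-array indexing on the first axis picks whole rows, so all columns
# of a given shuffle share the same permutation.
weights_matrix = weights[idx]  # shape (num_shuffles, npts, nweights)
```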

Sample run -

In [28]: weights
Out[28]: 
array([[ 0.22508764,  0.8527072 ],
       [ 0.31504052,  0.73272155],
       [ 0.73370203,  0.54889059],
       [ 0.87470619,  0.12394942],
       [ 0.20587307,  0.11385946]])

In [29]: num_shuffles = 3

In [30]: weights[np.random.rand(num_shuffles,weights.shape[0]).argsort(1)]
Out[30]: 
array([[[ 0.87470619,  0.12394942],
        [ 0.20587307,  0.11385946],
        [ 0.22508764,  0.8527072 ],
        [ 0.31504052,  0.73272155],
        [ 0.73370203,  0.54889059]],

       [[ 0.87470619,  0.12394942],
        [ 0.22508764,  0.8527072 ],
        [ 0.73370203,  0.54889059],
        [ 0.20587307,  0.11385946],
        [ 0.31504052,  0.73272155]],

       [[ 0.73370203,  0.54889059],
        [ 0.31504052,  0.73272155],
        [ 0.22508764,  0.8527072 ],
        [ 0.20587307,  0.11385946],
        [ 0.87470619,  0.12394942]]])
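If NumPy 1.20 or later is available, Generator.permuted offers another way to build the index array without sorting random keys: with axis=1 it shuffles each row of a 2-D array independently. A sketch under that version assumption (seed arbitrary); the final fancy-indexing step is the same as above:

```python
import numpy as np

rng = np.random.default_rng(7)  # arbitrary seed
npts, nweights, num_shuffles = 5, 2, 3
weights = rng.random((npts, nweights))

# num_shuffles copies of 0..npts-1, one per row
base = np.tile(np.arange(npts), (num_shuffles, 1))  # shape (num_shuffles, npts)

# Each row is shuffled independently, so each row is its own random permutation
idx = rng.permuted(base, axis=1)

weights_matrix = weights[idx]  # shape (num_shuffles, npts, nweights)
```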