在 Python 中移动矩阵行的最快方法

Fastest way to shift rows of matrix in Python

我有一个这样的 4x4 矩阵:

1  2  3  4
5  6  7  8
9  10 11 12
13 14 15 16

我想将每一行向左移动(向左循环移动),移动量为行索引。 IE。第 0 行保持原样,第 1 行左移 1,第 2 行左移 2,等等

所以我们得到这个:

1  2  3  4
6  7  8  5
11 12 9  10
16 13 14 15

我在 Python 中想出的最快方法如下:

import numpy as np
def ShiftRows(x):
    x[1:] = [np.append(x[i][i:], x[i][:i]) for i in range(1, 4)]
    return x

我需要在数千个 4x4 矩阵上 运行 这个函数,所以速度很重要(在 Python 中尽可能)。我不关心使用其他模块,例如 numpy,我只关心速度。

非常感谢任何帮助!

谢谢!

第一次改进,摆脱列表理解

我假设您的输入始终是 4x4 ndarray。如果没有,您需要适当地修改函数(即添加 np.asarray、检查维度等)。删除列表理解已经提供了很好的加速:

import numpy as np

a = np.arange(16).reshape(4, 4)

def ShiftRows(x):
    x[1:] = [np.append(x[i][i:], x[i][:i]) for i in range(1, 4)]
    return x

def shift(x):
    for i in range(1, 4):
        x[i] = np.append(x[i, i:], x[i, :i])
    return x

%timeit ShiftRows(a)
# 38.6 µs ± 1.08 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%timeit shift(a)
# 31.9 µs ± 583 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

请记住,这两种变体都会就地修改数组。如果这不是您想要的,请在两个函数的开头添加一个 x = x.copy()

根据我的测试 numpy.roll 比任何一个版本都慢得多。

第二次改进,使用numba

现在,真正的加速出现在我们使用 numba:

import numba

@numba.njit
def shift_numba(x):
    for i in range(1, 4):
        x[i] = np.append(x[i, i:], x[i, :i])
    return x    

%timeit shift_numba(a)
# 2.5 µs ± 115 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

这比您现在的速度快大约 15 倍。使用 parallel 模式不会提高性能,可能是因为数组的尺寸很小。


测试:展开循环

应 Patrick Artner 的要求,我展开了循环(4x4 很可能):

@numba.njit
def shift_numba_unrolled(x):
    x[1] = np.append(x[1, 1:], x[1, :1])
    x[2] = np.append(x[2, 2:], x[2, :2])
    x[3] = np.append(x[3, 3:], x[3, :3])
    return x

%timeit shift_numba_unrolled(a)
# 2.49 µs ± 85 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

展开似乎产生相同的结果。


编辑:解决了一个大问题,现在加速比现在少了很多。

仅使用列表列表的基本解决方案(对于那些在没有考虑 numpy 的情况下发现这个问题的人):

import numpy as np

def SimpleShift(x):
    for i in range(1,4):
        # inplace slicing
        x[i][:] = x[i][i:] + x[i][:i]
    return x

def EvenSimplerShift(x):
    # manually unrolled loop
    x[1][:] = x[1][1:] + x[1][:1]
    x[2][:] = x[2][2:] + x[2][:2]
    x[3][:] = x[3][3:] + x[3][:3]
    return x

from timeit import timeit

data = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]

print(data)
print(SimpleShift(data))
print(EvenSimplerShift(data))

print(timeit(lambda:SimpleShift(data)))
print(timeit(lambda: EvenSimplerShift(data)))

获得

[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
[[1, 2, 3, 4], [6, 7, 8, 5], [11, 12, 9, 10], [16, 13, 14, 15]]
[[1, 2, 3, 4], [6, 7, 8, 5], [11, 12, 9, 10], [16, 13, 14, 15]]

4.8055571                 # timing with for loop
4.098531100000001         # timing with unrolled loop

所以手动展开循环似乎更快。您可能还想使用 numpy 查看一下。

这个有效:

import numpy as np


def stepped_roll(arr):
    return np.array([np.roll(row, -n) for n, row in enumerate(arr)])


print(stepped_roll(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]])))

我更喜欢使用 np.roll,因为 numpy 例程往往比您在 Python 中可以做的更快。遗憾的是,np.apply_along_axis 在这里不起作用,因为你需要每一行的索引。

虽然在你的情况下,操作是如此微不足道,数据集如此之小,@JanChristophTerasa 的回答中建议的 shift() 函数会更快。

如果您不介意硬编码数组大小,在我的测试中,硬编码重排模式的速度大约是硬编码的 6 倍:

def rot(a):
    return a.take((0,1,2,3,5,6,7,4,10,11,8,9,15,12,13,14)).reshape(4, 4)