在 Python 中移动矩阵行的最快方法
Fastest way to shift rows of matrix in Python
我有一个这样的 4x4 矩阵:
1 2 3 4
5 6 7 8
9 10 11 12
13 14 15 16
我想将每一行向左移动(向左循环移动),移动量为行索引。 IE。第 0 行保持原样,第 1 行左移 1,第 2 行左移 2,等等
所以我们得到这个:
1 2 3 4
6 7 8 5
11 12 9 10
16 13 14 15
我在 Python 中想出的最快方法如下:
import numpy as np
def ShiftRows(x):
x[1:] = [np.append(x[i][i:], x[i][:i]) for i in range(1, 4)]
return x
我需要在数千个 4x4 矩阵上 运行 这个函数,所以速度很重要(在 Python 中尽可能)。我不关心使用其他模块,例如 numpy,我只关心速度。
非常感谢任何帮助!
谢谢!
第一次改进,摆脱列表理解
我假设您的输入始终是 4x4 ndarray。如果没有,您需要适当地修改函数(即添加 np.asarray
、检查维度等)。删除列表理解已经提供了很好的加速:
import numpy as np
a = np.arange(16).reshape(4, 4)
def ShiftRows(x):
x[1:] = [np.append(x[i][i:], x[i][:i]) for i in range(1, 4)]
return x
def shift(x):
for i in range(1, 4):
x[i] = np.append(x[i, i:], x[i, :i])
return x
%timeit ShiftRows(a)
# 38.6 µs ± 1.08 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit shift(a)
# 31.9 µs ± 583 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
请记住,这两种变体都会就地修改数组。如果这不是您想要的,请在两个函数的开头添加一个 x = x.copy()
。
根据我的测试 numpy.roll
比任何一个版本都慢得多。
第二次改进,使用numba
现在,真正的加速出现在我们使用 numba
:
import numba
@numba.njit
def shift_numba(x):
for i in range(1, 4):
x[i] = np.append(x[i, i:], x[i, :i])
return x
%timeit shift_numba(a)
# 2.5 µs ± 115 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
这比您现在的速度快大约 15 倍。使用 parallel
模式不会提高性能,可能是因为数组的尺寸很小。
测试:展开循环
应 Patrick Artner 的要求,我展开了循环(4x4 很可能):
@numba.njit
def shift_numba_unrolled(x):
x[1] = np.append(x[1, 1:], x[1, :1])
x[2] = np.append(x[2, 2:], x[2, :2])
x[3] = np.append(x[3, 3:], x[3, :3])
return x
%timeit shift_numba_unrolled(a)
# 2.49 µs ± 85 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
展开似乎产生相同的结果。
编辑:解决了一个大问题,现在加速比现在少了很多。
仅使用列表列表的基本解决方案(对于那些在没有考虑 numpy 的情况下发现这个问题的人):
import numpy as np
def SimpleShift(x):
for i in range(1,4):
# inplace slicing
x[i][:] = x[i][i:] + x[i][:i]
return x
def EvenSimplerShift(x):
# manually unrolled loop
x[1][:] = x[1][1:] + x[1][:1]
x[2][:] = x[2][2:] + x[2][:2]
x[3][:] = x[3][3:] + x[3][:3]
return x
from timeit import timeit
data = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
print(data)
print(SimpleShift(data))
print(EvenSimplerShift(data))
print(timeit(lambda:SimpleShift(data)))
print(timeit(lambda: EvenSimplerShift(data)))
获得
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
[[1, 2, 3, 4], [6, 7, 8, 5], [11, 12, 9, 10], [16, 13, 14, 15]]
[[1, 2, 3, 4], [6, 7, 8, 5], [11, 12, 9, 10], [16, 13, 14, 15]]
4.8055571 # timing with for loop
4.098531100000001 # timing with unrolled loop
所以手动展开循环似乎更快。您可能还想使用 numpy 查看一下。
这个有效:
import numpy as np
def stepped_roll(arr):
return np.array([np.roll(row, -n) for n, row in enumerate(arr)])
print(stepped_roll(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]])))
我更喜欢使用 np.roll
,因为 numpy 例程往往比您在 Python 中可以做的更快。遗憾的是,np.apply_along_axis
在这里不起作用,因为你需要每一行的索引。
虽然在你的情况下,操作是如此微不足道,数据集如此之小,@JanChristophTerasa 的回答中建议的 shift()
函数会更快。
如果您不介意硬编码数组大小,在我的测试中,硬编码重排模式的速度大约是硬编码的 6 倍:
def rot(a):
return a.take((0,1,2,3,5,6,7,4,10,11,8,9,15,12,13,14)).reshape(4, 4)
我有一个这样的 4x4 矩阵:
1 2 3 4
5 6 7 8
9 10 11 12
13 14 15 16
我想将每一行向左移动(向左循环移动),移动量为行索引。 IE。第 0 行保持原样,第 1 行左移 1,第 2 行左移 2,等等
所以我们得到这个:
1 2 3 4
6 7 8 5
11 12 9 10
16 13 14 15
我在 Python 中想出的最快方法如下:
import numpy as np
def ShiftRows(x):
x[1:] = [np.append(x[i][i:], x[i][:i]) for i in range(1, 4)]
return x
我需要在数千个 4x4 矩阵上 运行 这个函数,所以速度很重要(在 Python 中尽可能)。我不关心使用其他模块,例如 numpy,我只关心速度。
非常感谢任何帮助!
谢谢!
第一次改进,摆脱列表理解
我假设您的输入始终是 4x4 ndarray。如果没有,您需要适当地修改函数(即添加 np.asarray
、检查维度等)。删除列表理解已经提供了很好的加速:
import numpy as np
a = np.arange(16).reshape(4, 4)
def ShiftRows(x):
x[1:] = [np.append(x[i][i:], x[i][:i]) for i in range(1, 4)]
return x
def shift(x):
for i in range(1, 4):
x[i] = np.append(x[i, i:], x[i, :i])
return x
%timeit ShiftRows(a)
# 38.6 µs ± 1.08 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit shift(a)
# 31.9 µs ± 583 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
请记住,这两种变体都会就地修改数组。如果这不是您想要的,请在两个函数的开头添加一个 x = x.copy()
。
根据我的测试 numpy.roll
比任何一个版本都慢得多。
第二次改进,使用numba
现在,真正的加速出现在我们使用 numba
:
import numba
@numba.njit
def shift_numba(x):
for i in range(1, 4):
x[i] = np.append(x[i, i:], x[i, :i])
return x
%timeit shift_numba(a)
# 2.5 µs ± 115 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
这比您现在的速度快大约 15 倍。使用 parallel
模式不会提高性能,可能是因为数组的尺寸很小。
测试:展开循环
应 Patrick Artner 的要求,我展开了循环(4x4 很可能):
@numba.njit
def shift_numba_unrolled(x):
x[1] = np.append(x[1, 1:], x[1, :1])
x[2] = np.append(x[2, 2:], x[2, :2])
x[3] = np.append(x[3, 3:], x[3, :3])
return x
%timeit shift_numba_unrolled(a)
# 2.49 µs ± 85 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
展开似乎产生相同的结果。
编辑:解决了一个大问题,现在加速比现在少了很多。
仅使用列表列表的基本解决方案(对于那些在没有考虑 numpy 的情况下发现这个问题的人):
import numpy as np
def SimpleShift(x):
for i in range(1,4):
# inplace slicing
x[i][:] = x[i][i:] + x[i][:i]
return x
def EvenSimplerShift(x):
# manually unrolled loop
x[1][:] = x[1][1:] + x[1][:1]
x[2][:] = x[2][2:] + x[2][:2]
x[3][:] = x[3][3:] + x[3][:3]
return x
from timeit import timeit
data = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
print(data)
print(SimpleShift(data))
print(EvenSimplerShift(data))
print(timeit(lambda:SimpleShift(data)))
print(timeit(lambda: EvenSimplerShift(data)))
获得
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
[[1, 2, 3, 4], [6, 7, 8, 5], [11, 12, 9, 10], [16, 13, 14, 15]]
[[1, 2, 3, 4], [6, 7, 8, 5], [11, 12, 9, 10], [16, 13, 14, 15]]
4.8055571 # timing with for loop
4.098531100000001 # timing with unrolled loop
所以手动展开循环似乎更快。您可能还想使用 numpy 查看一下。
这个有效:
import numpy as np
def stepped_roll(arr):
return np.array([np.roll(row, -n) for n, row in enumerate(arr)])
print(stepped_roll(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]])))
我更喜欢使用 np.roll
,因为 numpy 例程往往比您在 Python 中可以做的更快。遗憾的是,np.apply_along_axis
在这里不起作用,因为你需要每一行的索引。
虽然在你的情况下,操作是如此微不足道,数据集如此之小,@JanChristophTerasa 的回答中建议的 shift()
函数会更快。
如果您不介意硬编码数组大小,在我的测试中,硬编码重排模式的速度大约是硬编码的 6 倍:
def rot(a):
return a.take((0,1,2,3,5,6,7,4,10,11,8,9,15,12,13,14)).reshape(4, 4)