使用条件语句加速 Python 嵌套循环
Speed up Python nested loops with conditional statements
我正在将代码从 MATLAB 转换为 python 以加快简单操作的速度。我写了一个包含嵌套循环和条件语句的函数;循环的目的是 return 与数组 y 比较时数组 x 中最近元素的索引列表。我正在按 1e5 项的顺序进行比较,这大约需要 30 秒到 运行。任何有助于加快此过程的帮助将不胜感激!我在使用 numba-pro 自动即时编译器方面取得了部分成功:
@autojit()
def find_nearest(x,y,idx):
idx_old = 0
rng1 = range(y.shape[0])
rng2 = range(x.shape[0])
for i in rng1:
prev = abs(x[idx_old]-y[i])
for j in rng2:
if abs(x[j]-y[i]) < prev:
prev = abs(x[j]-y[i])
idx_old = j
idx[i] = idx_old
return idx
抱歉我是个菜鸟,我是 python 的新手!
我找到了解决问题的临时方法。通过实施 scipy.spatial 的 kdtree,我能够将 运行 时间从 32 秒减少到不到 10 秒。这仍然比 MATLAB knnsearch 算法慢四倍;了解如何使用条件语句加速循环仍然很重要。但目前这个修改后的实现速度更快:
from scipy import spatial
from numpy import matrix
tree = spatial.KDTree(matrix(x).T)
(_, idxx) = tree.query(matrix(y).T)
数组 x 和 y 是平面一维格式;树要求查询采用列向量形式。
如有任何改进 运行 原始实施时间的建议,我们将不胜感激!
您的 Numba 代码没有任何问题,只是该算法没有达到应有的效率。更好的方法是对 x
数组进行排序并进行二进制搜索,非常类似于 this answer and also this answer:
def find_nearest(x, y):
indices = np.argsort(x)
loc = np.searchsorted(x[indices], y)
right = indices.take(loc, mode='clip')
left = indices.take(loc-1, mode='clip')
return np.where(abs(y-x[left]) < abs(y-x[right]), left, right)
在我的 PC 上,对于具有 106 和 10 的 x
和 y
,这比 KDTree 方法快大约 80 倍分别有 5 个元素。大约三分之二的时间花在 argsort
-ing 数组上,所以我认为您在这里使用 Numba 不会有太多收获。
我正在将代码从 MATLAB 转换为 python 以加快简单操作的速度。我写了一个包含嵌套循环和条件语句的函数;循环的目的是 return 与数组 y 比较时数组 x 中最近元素的索引列表。我正在按 1e5 项的顺序进行比较,这大约需要 30 秒到 运行。任何有助于加快此过程的帮助将不胜感激!我在使用 numba-pro 自动即时编译器方面取得了部分成功:
@autojit()
def find_nearest(x,y,idx):
idx_old = 0
rng1 = range(y.shape[0])
rng2 = range(x.shape[0])
for i in rng1:
prev = abs(x[idx_old]-y[i])
for j in rng2:
if abs(x[j]-y[i]) < prev:
prev = abs(x[j]-y[i])
idx_old = j
idx[i] = idx_old
return idx
抱歉我是个菜鸟,我是 python 的新手!
我找到了解决问题的临时方法。通过实施 scipy.spatial 的 kdtree,我能够将 运行 时间从 32 秒减少到不到 10 秒。这仍然比 MATLAB knnsearch 算法慢四倍;了解如何使用条件语句加速循环仍然很重要。但目前这个修改后的实现速度更快:
from scipy import spatial
from numpy import matrix
tree = spatial.KDTree(matrix(x).T)
(_, idxx) = tree.query(matrix(y).T)
数组 x 和 y 是平面一维格式;树要求查询采用列向量形式。
如有任何改进 运行 原始实施时间的建议,我们将不胜感激!
您的 Numba 代码没有任何问题,只是该算法没有达到应有的效率。更好的方法是对 x
数组进行排序并进行二进制搜索,非常类似于 this answer and also this answer:
def find_nearest(x, y):
indices = np.argsort(x)
loc = np.searchsorted(x[indices], y)
right = indices.take(loc, mode='clip')
left = indices.take(loc-1, mode='clip')
return np.where(abs(y-x[left]) < abs(y-x[right]), left, right)
在我的 PC 上,对于具有 106 和 10 的 x
和 y
,这比 KDTree 方法快大约 80 倍分别有 5 个元素。大约三分之二的时间花在 argsort
-ing 数组上,所以我认为您在这里使用 Numba 不会有太多收获。