Python - 如何加速 for 循环从另一个 numpy 数组计算创建一个 numpy 数组
Python - how to speed up a for loop creating a numpy array from another numpy array calculation
首先,对于含糊不清的标题,我深表歉意,我想不出适合这个问题的名称。
我有 3 个 numpy 数组,格式如下:
N = ([[13, 14, 15], [2, 5, 7], [4, 6, 8] ... 几十万个元素长
e1 = [1, 0, 0]
e2 = [0, 1, 0]
想法是创建第四个数组 'v',其维度应与 'N' 相同,但将根据 if 语句赋值。这是我目前拥有的,应该可以更好地解释这个问题:
v = np.zeros([len(N), 3])
for i in range(0, len(N)):
if((N*e1)[i,0] != 0):
v[i] = np.cross(N[i],e1)
else:
v[i] = np.cross(N[i],e2)
此代码可以满足我的要求,但需要比预期更长的时间(> 5 分钟)。我可以使用任何形式的列表理解或类似概念来提高代码效率吗?
您可以使用 numpy.where
替换 if-else 并使用广播对过程进行矢量化,这是 numpy.where
的选项:
import numpy as np
np.where(np.repeat(N[:,0] != 0, 3).reshape(1000,3), np.cross(N, e1), np.cross(N, e2))
这里有一些基准测试:
1) 数据设置:
N = np.array([np.random.randint(0,10,3) for i in range(1000)])
N
#array([[3, 5, 0],
# [5, 0, 8],
# [4, 6, 0],
# ...,
# [9, 4, 2],
# [6, 9, 3],
# [2, 9, 2]])
e1 = np.array([1, 0, 0])
e2 = np.array([0, 1, 0])
2) 时间:
def forloop():
v = np.zeros([len(N), 3]);
for i in range(0, len(N)):
if((N*e1)[i,0] != 0):
v[i] = np.cross(N[i],e1)
else:
v[i] = np.cross(N[i],e2)
return v
def forloop2():
v = np.zeros([len(N), 3])
# Only calculate this one time.
my_product = N*e1
for i in range(0, len(N)):
if my_product[i,0] != 0:
v[i] = np.cross(N[i],e1)
else:
v[i] = np.cross(N[i],e2)
return v
%timeit forloop()
10 loops, best of 3: 25.5 ms per loop
%timeit forloop2()
100 loops, best of 3: 12.7 ms per loop
%timeit np.where(np.repeat(N[:,0] != 0, 3).reshape(1000,3), np.cross(N, e1), np.cross(N, e2))
10000 loops, best of 3: 71.9 µs per loop
3) 所有方法的结果检查:
v1 = forloop()
v2 = np.where(np.repeat(N[:,0] != 0, 3).reshape(1000,3), np.cross(N, e1), np.cross(N, e2))
v3 = forloop2()
(v3 == v1).all()
# True
(v1 == v2).all()
# True
我不确定您要做什么,但我知道为什么这段 specific 代码对您来说这么慢。最严重的违规者是 (N*e1)
。这是一个简单的计算,它在 numpy 中运行得非常快,但是你在循环内执行它,len(N)
次!
通过将代码拉出循环,我可以在我的机器上用 N == 1000000
在不到 15 秒的时间内执行您的代码。下面的例子。
v = np.zeros([len(N), 3])
# Only calculate this one time.
my_product = N*e1
for i in range(0, len(N)):
if my_product[i,0] != 0):
v[i] = np.cross(N[i],e1)
else:
v[i] = np.cross(N[i],e2)
另一个答案演示了如何以可读性稍差的代码为代价来避免 for 循环和 if 语句以获得更多的速度。
首先,对于含糊不清的标题,我深表歉意,我想不出适合这个问题的名称。
我有 3 个 numpy 数组,格式如下:
N = ([[13, 14, 15], [2, 5, 7], [4, 6, 8] ... 几十万个元素长
e1 = [1, 0, 0]
e2 = [0, 1, 0]
想法是创建第四个数组 'v',其维度应与 'N' 相同,但将根据 if 语句赋值。这是我目前拥有的,应该可以更好地解释这个问题:
v = np.zeros([len(N), 3])
for i in range(0, len(N)):
if((N*e1)[i,0] != 0):
v[i] = np.cross(N[i],e1)
else:
v[i] = np.cross(N[i],e2)
此代码可以满足我的要求,但需要比预期更长的时间(> 5 分钟)。我可以使用任何形式的列表理解或类似概念来提高代码效率吗?
您可以使用 numpy.where
替换 if-else 并使用广播对过程进行矢量化,这是 numpy.where
的选项:
import numpy as np
np.where(np.repeat(N[:,0] != 0, 3).reshape(1000,3), np.cross(N, e1), np.cross(N, e2))
这里有一些基准测试:
1) 数据设置:
N = np.array([np.random.randint(0,10,3) for i in range(1000)])
N
#array([[3, 5, 0],
# [5, 0, 8],
# [4, 6, 0],
# ...,
# [9, 4, 2],
# [6, 9, 3],
# [2, 9, 2]])
e1 = np.array([1, 0, 0])
e2 = np.array([0, 1, 0])
2) 时间:
def forloop():
v = np.zeros([len(N), 3]);
for i in range(0, len(N)):
if((N*e1)[i,0] != 0):
v[i] = np.cross(N[i],e1)
else:
v[i] = np.cross(N[i],e2)
return v
def forloop2():
v = np.zeros([len(N), 3])
# Only calculate this one time.
my_product = N*e1
for i in range(0, len(N)):
if my_product[i,0] != 0:
v[i] = np.cross(N[i],e1)
else:
v[i] = np.cross(N[i],e2)
return v
%timeit forloop()
10 loops, best of 3: 25.5 ms per loop
%timeit forloop2()
100 loops, best of 3: 12.7 ms per loop
%timeit np.where(np.repeat(N[:,0] != 0, 3).reshape(1000,3), np.cross(N, e1), np.cross(N, e2))
10000 loops, best of 3: 71.9 µs per loop
3) 所有方法的结果检查:
v1 = forloop()
v2 = np.where(np.repeat(N[:,0] != 0, 3).reshape(1000,3), np.cross(N, e1), np.cross(N, e2))
v3 = forloop2()
(v3 == v1).all()
# True
(v1 == v2).all()
# True
我不确定您要做什么,但我知道为什么这段 specific 代码对您来说这么慢。最严重的违规者是 (N*e1)
。这是一个简单的计算,它在 numpy 中运行得非常快,但是你在循环内执行它,len(N)
次!
通过将代码拉出循环,我可以在我的机器上用 N == 1000000
在不到 15 秒的时间内执行您的代码。下面的例子。
v = np.zeros([len(N), 3])
# Only calculate this one time.
my_product = N*e1
for i in range(0, len(N)):
if my_product[i,0] != 0):
v[i] = np.cross(N[i],e1)
else:
v[i] = np.cross(N[i],e2)
另一个答案演示了如何以可读性稍差的代码为代价来避免 for 循环和 if 语句以获得更多的速度。