为什么 numba 不提高我的背包功能的速度?
Why numba don't improve the speed of my knapsack function?
我试图用 numba 加速我的代码,但它似乎不起作用。该程序与 @jit
、@njit
或纯 python 花费相同的时间(大约 10 秒)。但是我使用了 numpy 而不是 list 或 dict。
这是我的代码:
import numpy as np
from numba import njit
import random
import line_profiler
import atexit
profile = line_profiler.LineProfiler()
atexit.register(profile.print_stats)
@njit
def knapSack(W, wt, val, n):
K = np.full((n+1,W+1),0)
N = np.full((n+1,W+1,W+1),0)
M = np.full((n+1,W+1),0)
for i in range(n+1):
for w in range(W+1):
if i==0 or w==0:
K[i][w] = 0
elif wt[i-1] <= w:
if(val[i-1] + K[i-1][w-wt[i-1]] > K[i-1][w]):
K[i][w] = val[i-1] + K[i-1][w-wt[i-1]]
c = N[i-1][w-wt[i-1]]
c[i] = i
N[i][w] = c
else:
K[i][w] = K[i-1][w]
N[i][w] = N[i-1][w]
else:
K[i][w] = K[i-1][w]
N[n][W][0] = K[n][W]
return N[n][W]
@profile
def main():
size = 1000
val = [random.randint(1, size) for i in range(0, size)]
wt = [random.randint(1, size) for i in range(0, size)]
W = 1000
n = len(val)
a = knapSack(W, wt, val, n)
main()
事实上,如果不改变方法本身,就不可能真正提高当前算法的性能。
您的 N
数组包含大约 10 亿个对象 (1001 * 1001 * 1001
)。你需要设置每个元素,所以你至少有十亿次操作。为了获得下限,我们假设设置一个数组元素需要一纳秒(实际上需要更多时间)。 10 亿次操作,每次需要 1 纳秒意味着需要 1 秒才能完成。正如我所说,每次操作可能需要比 1 纳秒更长的时间,所以我们假设它需要 10 纳秒(可能有点高但比 1 纳秒更现实),这意味着我们总共有 10 秒的算法。
因此您输入的预期 运行 时间将在 1 秒到 10 秒之间。因此,如果您的 Python 版本需要 10 秒,它可能已经达到了您选择的方法所能达到的极限,并且没有任何工具可以(显着)改善 运行 时间。
可以使速度更快的一件事是使用 np.zeros
而不是 np.full
:
K = np.zeros((n+1, W+1), dtype=int)
N = np.zeros((n+1, W+1, W+1), dtype=int)
并且不要创建 M
,因为您不会使用它。
因为你已经使用过 line-profiler 我决定看一看,我得到了这个结果:
Line # Hits Time Per Hit % Time Line Contents
==============================================================
3 def knapSack(W, wt, val, n):
4 1 19137.0 19137.0 0.0 K = np.full((n+1,W+1),0)
5 1 19408592.0 19408592.0 28.1 N = np.full((n+1,W+1,W+1),0)
6
7 1002 6412.0 6.4 0.0 for i in range(n+1):
8 1003002 4186311.0 4.2 6.1 for w in range(W+1):
9 1002001 4644031.0 4.6 6.7 if i==0 or w==0:
10 2001 19663.0 9.8 0.0 K[i][w] = 0
11 1000000 5474080.0 5.5 7.9 elif wt[i-1] <= w:
12 498365 9616406.0 19.3 13.9 if(val[i-1] + K[i-1][w-wt[i-1]] > K[i-1][w]):
13 52596 902030.0 17.2 1.3 K[i][w] = val[i-1] + K[i-1][w-wt[i-1]]
14 52596 578740.0 11.0 0.8 c = N[i-1][w-wt[i-1]]
15 52596 295980.0 5.6 0.4 c[i] = i
16 52596 1239792.0 23.6 1.8 N[i][w] = c
17 else:
18 445769 5100917.0 11.4 7.4 K[i][w] = K[i-1][w]
19 445769 11677683.0 26.2 16.9 N[i][w] = N[i-1][w]
20 else:
21 501635 5801328.0 11.6 8.4 K[i][w] = K[i-1][w]
22 1 16.0 16.0 0.0 N[n][W][0] = K[n][W]
23 1 14.0 14.0 0.0 return N[n][W]
这表明瓶颈是np.full
、N[i][w] = N[i-1][w]
和if(val[i-1] + K[i-1][w-wt[i-1]] > K[i-1][w])
。 Numba 不会改进前两个,因为它们已经使用了高度优化的 NumPy 代码,而 numba 更有可能在这些方面变慢。 Numba 可能会改进 if(val[i-1] + K[i-1][w-wt[i-1]] > K[i-1][w])
但这可能不会引起注意。
如果 np.full
被 np.zeros
替换,配置文件会略有变化:
Line # Hits Time Per Hit % Time Line Contents
==============================================================
3 def knapSack(W, wt, val, n):
4 1 747.0 747.0 0.0 K = np.zeros((n+1, W+1),dtype=int)
5 1 109592.0 109592.0 0.2 N = np.zeros((n+1, W+1, W+1),dtype=int)
6
7 1002 4230.0 4.2 0.0 for i in range(n+1):
8 1003002 4414071.0 4.4 7.0 for w in range(W+1):
9 1002001 4836807.0 4.8 7.7 if i==0 or w==0:
10 2001 22282.0 11.1 0.0 K[i][w] = 0
11 1000000 5646859.0 5.6 8.9 elif wt[i-1] <= w:
12 521222 10389581.0 19.9 16.5 if(val[i-1] + K[i-1][w-wt[i-1]] > K[i-1][w]):
13 47579 784563.0 16.5 1.2 K[i][w] = val[i-1] + K[i-1][w-wt[i-1]]
14 47579 509056.0 10.7 0.8 c = N[i-1][w-wt[i-1]]
15 47579 362796.0 7.6 0.6 c[i] = i
16 47579 1975916.0 41.5 3.1 N[i][w] = c
17 else:
18 473643 5579823.0 11.8 8.8 K[i][w] = K[i-1][w]
19 473643 22805846.0 48.1 36.1 N[i][w] = N[i-1][w]
20 else:
21 478778 5664271.0 11.8 9.0 K[i][w] = K[i-1][w]
22 1 16.0 16.0 0.0 N[n][W][0] = K[n][W]
23 1 10.0 10.0 0.0 return N[n][W]
但主要瓶颈仍然是 N[i][w] = N[i-1][w]
,numba 可能比纯 NumPy 慢。因此,您在代码的其他一些部分使用 numba 获得的改进可能不会(再次)引起注意。
对于第一个配置文件,我使用了这个版本的代码(第二个配置文件只是将 np.full
更改为 np.zeros
):
import numpy as np
def knapSack(W, wt, val, n):
K = np.full((n+1,W+1),0)
N = np.full((n+1,W+1,W+1),0)
for i in range(n+1):
for w in range(W+1):
if i==0 or w==0:
K[i][w] = 0
elif wt[i-1] <= w:
if(val[i-1] + K[i-1][w-wt[i-1]] > K[i-1][w]):
K[i][w] = val[i-1] + K[i-1][w-wt[i-1]]
c = N[i-1][w-wt[i-1]]
c[i] = i
N[i][w] = c
else:
K[i][w] = K[i-1][w]
N[i][w] = N[i-1][w]
else:
K[i][w] = K[i-1][w]
N[n][W][0] = K[n][W]
return N[n][W]
import random
size = 1000
val = [random.randint(1, size) for i in range(0, size)]
wt = [random.randint(1, size) for i in range(0, size)]
W = 1000
n = len(val)
%lprun -f knapSack knapSack(W, wt, val, n)
这里是新函数:
@njit
def knapSack(W, wt, val, n):
K = np.zeros((n + 1, W + 1),dtype=np.int32)
# In fact we must only save the previous combinations and the current,
# not all :) So N is considerably reduce
N = np.zeros((2, W + 1, W + 1),dtype=np.int32)
for i in range(n + 1):
for w in range(W + 1):
if i == 0 or w == 0:
K[i][w] = 0
elif wt[i - 1] <= w:
if val[i - 1] + K[i - 1][w - wt[i - 1]] > K[i - 1][w]:
K[i][w] = val[i - 1] + K[i - 1][w - wt[i - 1]]
N[i%2][w] = np.copy(N[(i - 1)%2][w - wt[i - 1]])
N[i%2][w][i] = i
else:
K[i][w] = K[i - 1][w]
N[i%2][w] = N[(i - 1)%2][w]
else:
K[i][w] = K[i - 1][w]
N[(n)%2][W][0] = K[n][W]
return N[(n)%2][W]
非常感谢 MSeifert !!
我试图用 numba 加速我的代码,但它似乎不起作用。该程序与 @jit
、@njit
或纯 python 花费相同的时间(大约 10 秒)。但是我使用了 numpy 而不是 list 或 dict。
这是我的代码:
import numpy as np
from numba import njit
import random
import line_profiler
import atexit
profile = line_profiler.LineProfiler()
atexit.register(profile.print_stats)
@njit
def knapSack(W, wt, val, n):
K = np.full((n+1,W+1),0)
N = np.full((n+1,W+1,W+1),0)
M = np.full((n+1,W+1),0)
for i in range(n+1):
for w in range(W+1):
if i==0 or w==0:
K[i][w] = 0
elif wt[i-1] <= w:
if(val[i-1] + K[i-1][w-wt[i-1]] > K[i-1][w]):
K[i][w] = val[i-1] + K[i-1][w-wt[i-1]]
c = N[i-1][w-wt[i-1]]
c[i] = i
N[i][w] = c
else:
K[i][w] = K[i-1][w]
N[i][w] = N[i-1][w]
else:
K[i][w] = K[i-1][w]
N[n][W][0] = K[n][W]
return N[n][W]
@profile
def main():
size = 1000
val = [random.randint(1, size) for i in range(0, size)]
wt = [random.randint(1, size) for i in range(0, size)]
W = 1000
n = len(val)
a = knapSack(W, wt, val, n)
main()
事实上,如果不改变方法本身,就不可能真正提高当前算法的性能。
您的 N
数组包含大约 10 亿个对象 (1001 * 1001 * 1001
)。你需要设置每个元素,所以你至少有十亿次操作。为了获得下限,我们假设设置一个数组元素需要一纳秒(实际上需要更多时间)。 10 亿次操作,每次需要 1 纳秒意味着需要 1 秒才能完成。正如我所说,每次操作可能需要比 1 纳秒更长的时间,所以我们假设它需要 10 纳秒(可能有点高但比 1 纳秒更现实),这意味着我们总共有 10 秒的算法。
因此您输入的预期 运行 时间将在 1 秒到 10 秒之间。因此,如果您的 Python 版本需要 10 秒,它可能已经达到了您选择的方法所能达到的极限,并且没有任何工具可以(显着)改善 运行 时间。
可以使速度更快的一件事是使用 np.zeros
而不是 np.full
:
K = np.zeros((n+1, W+1), dtype=int)
N = np.zeros((n+1, W+1, W+1), dtype=int)
并且不要创建 M
,因为您不会使用它。
因为你已经使用过 line-profiler 我决定看一看,我得到了这个结果:
Line # Hits Time Per Hit % Time Line Contents
==============================================================
3 def knapSack(W, wt, val, n):
4 1 19137.0 19137.0 0.0 K = np.full((n+1,W+1),0)
5 1 19408592.0 19408592.0 28.1 N = np.full((n+1,W+1,W+1),0)
6
7 1002 6412.0 6.4 0.0 for i in range(n+1):
8 1003002 4186311.0 4.2 6.1 for w in range(W+1):
9 1002001 4644031.0 4.6 6.7 if i==0 or w==0:
10 2001 19663.0 9.8 0.0 K[i][w] = 0
11 1000000 5474080.0 5.5 7.9 elif wt[i-1] <= w:
12 498365 9616406.0 19.3 13.9 if(val[i-1] + K[i-1][w-wt[i-1]] > K[i-1][w]):
13 52596 902030.0 17.2 1.3 K[i][w] = val[i-1] + K[i-1][w-wt[i-1]]
14 52596 578740.0 11.0 0.8 c = N[i-1][w-wt[i-1]]
15 52596 295980.0 5.6 0.4 c[i] = i
16 52596 1239792.0 23.6 1.8 N[i][w] = c
17 else:
18 445769 5100917.0 11.4 7.4 K[i][w] = K[i-1][w]
19 445769 11677683.0 26.2 16.9 N[i][w] = N[i-1][w]
20 else:
21 501635 5801328.0 11.6 8.4 K[i][w] = K[i-1][w]
22 1 16.0 16.0 0.0 N[n][W][0] = K[n][W]
23 1 14.0 14.0 0.0 return N[n][W]
这表明瓶颈是np.full
、N[i][w] = N[i-1][w]
和if(val[i-1] + K[i-1][w-wt[i-1]] > K[i-1][w])
。 Numba 不会改进前两个,因为它们已经使用了高度优化的 NumPy 代码,而 numba 更有可能在这些方面变慢。 Numba 可能会改进 if(val[i-1] + K[i-1][w-wt[i-1]] > K[i-1][w])
但这可能不会引起注意。
如果 np.full
被 np.zeros
替换,配置文件会略有变化:
Line # Hits Time Per Hit % Time Line Contents
==============================================================
3 def knapSack(W, wt, val, n):
4 1 747.0 747.0 0.0 K = np.zeros((n+1, W+1),dtype=int)
5 1 109592.0 109592.0 0.2 N = np.zeros((n+1, W+1, W+1),dtype=int)
6
7 1002 4230.0 4.2 0.0 for i in range(n+1):
8 1003002 4414071.0 4.4 7.0 for w in range(W+1):
9 1002001 4836807.0 4.8 7.7 if i==0 or w==0:
10 2001 22282.0 11.1 0.0 K[i][w] = 0
11 1000000 5646859.0 5.6 8.9 elif wt[i-1] <= w:
12 521222 10389581.0 19.9 16.5 if(val[i-1] + K[i-1][w-wt[i-1]] > K[i-1][w]):
13 47579 784563.0 16.5 1.2 K[i][w] = val[i-1] + K[i-1][w-wt[i-1]]
14 47579 509056.0 10.7 0.8 c = N[i-1][w-wt[i-1]]
15 47579 362796.0 7.6 0.6 c[i] = i
16 47579 1975916.0 41.5 3.1 N[i][w] = c
17 else:
18 473643 5579823.0 11.8 8.8 K[i][w] = K[i-1][w]
19 473643 22805846.0 48.1 36.1 N[i][w] = N[i-1][w]
20 else:
21 478778 5664271.0 11.8 9.0 K[i][w] = K[i-1][w]
22 1 16.0 16.0 0.0 N[n][W][0] = K[n][W]
23 1 10.0 10.0 0.0 return N[n][W]
但主要瓶颈仍然是 N[i][w] = N[i-1][w]
,numba 可能比纯 NumPy 慢。因此,您在代码的其他一些部分使用 numba 获得的改进可能不会(再次)引起注意。
对于第一个配置文件,我使用了这个版本的代码(第二个配置文件只是将 np.full
更改为 np.zeros
):
import numpy as np
def knapSack(W, wt, val, n):
K = np.full((n+1,W+1),0)
N = np.full((n+1,W+1,W+1),0)
for i in range(n+1):
for w in range(W+1):
if i==0 or w==0:
K[i][w] = 0
elif wt[i-1] <= w:
if(val[i-1] + K[i-1][w-wt[i-1]] > K[i-1][w]):
K[i][w] = val[i-1] + K[i-1][w-wt[i-1]]
c = N[i-1][w-wt[i-1]]
c[i] = i
N[i][w] = c
else:
K[i][w] = K[i-1][w]
N[i][w] = N[i-1][w]
else:
K[i][w] = K[i-1][w]
N[n][W][0] = K[n][W]
return N[n][W]
import random
size = 1000
val = [random.randint(1, size) for i in range(0, size)]
wt = [random.randint(1, size) for i in range(0, size)]
W = 1000
n = len(val)
%lprun -f knapSack knapSack(W, wt, val, n)
这里是新函数:
@njit
def knapSack(W, wt, val, n):
K = np.zeros((n + 1, W + 1),dtype=np.int32)
# In fact we must only save the previous combinations and the current,
# not all :) So N is considerably reduce
N = np.zeros((2, W + 1, W + 1),dtype=np.int32)
for i in range(n + 1):
for w in range(W + 1):
if i == 0 or w == 0:
K[i][w] = 0
elif wt[i - 1] <= w:
if val[i - 1] + K[i - 1][w - wt[i - 1]] > K[i - 1][w]:
K[i][w] = val[i - 1] + K[i - 1][w - wt[i - 1]]
N[i%2][w] = np.copy(N[(i - 1)%2][w - wt[i - 1]])
N[i%2][w][i] = i
else:
K[i][w] = K[i - 1][w]
N[i%2][w] = N[(i - 1)%2][w]
else:
K[i][w] = K[i - 1][w]
N[(n)%2][W][0] = K[n][W]
return N[(n)%2][W]
非常感谢 MSeifert !!