为什么 numba 不提高我的背包功能的速度?

Why numba don't improve the speed of my knapsack function?

我试图用 numba 加速我的代码,但它似乎不起作用。该程序与 @jit@njit 或纯 python 花费相同的时间(大约 10 秒)。但是我使用了 numpy 而不是 list 或 dict。

这是我的代码:

import numpy as np
from numba import njit
import random
import line_profiler
import atexit
profile = line_profiler.LineProfiler()
atexit.register(profile.print_stats)

@njit
def knapSack(W, wt, val, n):
    K = np.full((n+1,W+1),0)
    N =  np.full((n+1,W+1,W+1),0)
    M =  np.full((n+1,W+1),0)

    for i in range(n+1):
        for w in range(W+1):
            if i==0 or w==0:
                K[i][w] = 0
            elif wt[i-1] <= w:
                if(val[i-1] + K[i-1][w-wt[i-1]] >  K[i-1][w]):
                    K[i][w] = val[i-1] + K[i-1][w-wt[i-1]]
                    c = N[i-1][w-wt[i-1]]
                    c[i] = i
                    N[i][w] = c
                else:
                    K[i][w] = K[i-1][w]
                    N[i][w] = N[i-1][w]
            else:
                K[i][w] = K[i-1][w]
    N[n][W][0] = K[n][W]
    return N[n][W]

@profile
def main():

    size = 1000
    val = [random.randint(1, size) for i in range(0, size)]
    wt = [random.randint(1, size) for i in range(0, size)]
    W = 1000
    n = len(val)
    a = knapSack(W, wt, val, n)
main()

事实上,如果不改变方法本身,就不可能真正提高当前算法的性能。

您的 N 数组包含大约 10 亿个对象 (1001 * 1001 * 1001)。你需要设置每个元素,所以你至少有十亿次操作。为了获得下限,我们假设设置一个数组元素需要一纳秒(实际上需要更多时间)。 10 亿次操作,每次需要 1 纳秒意味着需要 1 秒才能完成。正如我所说,每次操作可能需要比 1 纳秒更长的时间,所以我们假设它需要 10 纳秒(可能有点高但比 1 纳秒更现实),这意味着我们总共有 10 秒的算法。

因此您输入的预期 运行 时间将在 1 秒到 10 秒之间。因此,如果您的 Python 版本需要 10 秒,它可能已经达到了您选择的方法所能达到的极限,并且没有任何工具可以(显着)改善 运行 时间。


可以使速度更快的一件事是使用 np.zeros 而不是 np.full:

K = np.zeros((n+1, W+1), dtype=int)
N = np.zeros((n+1, W+1, W+1), dtype=int)

并且不要创建 M,因为您不会使用它。


因为你已经使用过 line-profiler 我决定看一看,我得到了这个结果:

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     3                                           def knapSack(W, wt, val, n):
     4         1      19137.0  19137.0      0.0      K = np.full((n+1,W+1),0)
     5         1   19408592.0 19408592.0     28.1      N = np.full((n+1,W+1,W+1),0)
     6                                           
     7      1002       6412.0      6.4      0.0      for i in range(n+1):
     8   1003002    4186311.0      4.2      6.1          for w in range(W+1):
     9   1002001    4644031.0      4.6      6.7              if i==0 or w==0:
    10      2001      19663.0      9.8      0.0                  K[i][w] = 0
    11   1000000    5474080.0      5.5      7.9              elif wt[i-1] <= w:
    12    498365    9616406.0     19.3     13.9                  if(val[i-1] + K[i-1][w-wt[i-1]] >  K[i-1][w]):
    13     52596     902030.0     17.2      1.3                      K[i][w] = val[i-1] + K[i-1][w-wt[i-1]]
    14     52596     578740.0     11.0      0.8                      c = N[i-1][w-wt[i-1]]
    15     52596     295980.0      5.6      0.4                      c[i] = i
    16     52596    1239792.0     23.6      1.8                      N[i][w] = c
    17                                                           else:
    18    445769    5100917.0     11.4      7.4                      K[i][w] = K[i-1][w]
    19    445769   11677683.0     26.2     16.9                      N[i][w] = N[i-1][w]
    20                                                       else:
    21    501635    5801328.0     11.6      8.4                  K[i][w] = K[i-1][w]
    22         1         16.0     16.0      0.0      N[n][W][0] = K[n][W]
    23         1         14.0     14.0      0.0      return N[n][W]

这表明瓶颈是np.fullN[i][w] = N[i-1][w]if(val[i-1] + K[i-1][w-wt[i-1]] > K[i-1][w])。 Numba 不会改进前两个,因为它们已经使用了高度优化的 NumPy 代码,而 numba 更有可能在这些方面变慢。 Numba 可能会改进 if(val[i-1] + K[i-1][w-wt[i-1]] > K[i-1][w]) 但这可能不会引起注意。

如果 np.fullnp.zeros 替换,配置文件会略有变化:

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     3                                           def knapSack(W, wt, val, n):
     4         1        747.0    747.0      0.0      K = np.zeros((n+1, W+1),dtype=int)
     5         1     109592.0 109592.0      0.2      N = np.zeros((n+1, W+1, W+1),dtype=int)
     6                                           
     7      1002       4230.0      4.2      0.0      for i in range(n+1):
     8   1003002    4414071.0      4.4      7.0          for w in range(W+1):
     9   1002001    4836807.0      4.8      7.7              if i==0 or w==0:
    10      2001      22282.0     11.1      0.0                  K[i][w] = 0
    11   1000000    5646859.0      5.6      8.9              elif wt[i-1] <= w:
    12    521222   10389581.0     19.9     16.5                  if(val[i-1] + K[i-1][w-wt[i-1]] >  K[i-1][w]):
    13     47579     784563.0     16.5      1.2                      K[i][w] = val[i-1] + K[i-1][w-wt[i-1]]
    14     47579     509056.0     10.7      0.8                      c = N[i-1][w-wt[i-1]]
    15     47579     362796.0      7.6      0.6                      c[i] = i
    16     47579    1975916.0     41.5      3.1                      N[i][w] = c
    17                                                           else:
    18    473643    5579823.0     11.8      8.8                      K[i][w] = K[i-1][w]
    19    473643   22805846.0     48.1     36.1                      N[i][w] = N[i-1][w]
    20                                                       else:
    21    478778    5664271.0     11.8      9.0                  K[i][w] = K[i-1][w]
    22         1         16.0     16.0      0.0      N[n][W][0] = K[n][W]
    23         1         10.0     10.0      0.0      return N[n][W]

但主要瓶颈仍然是 N[i][w] = N[i-1][w],numba 可能比纯 NumPy 慢。因此,您在代码的其他一些部分使用 numba 获得的改进可能不会(再次)引起注意。


对于第一个配置文件,我使用了这个版本的代码(第二个配置文件只是将 np.full 更改为 np.zeros):

import numpy as np

def knapSack(W, wt, val, n):
    K = np.full((n+1,W+1),0)
    N = np.full((n+1,W+1,W+1),0)

    for i in range(n+1):
        for w in range(W+1):
            if i==0 or w==0:
                K[i][w] = 0
            elif wt[i-1] <= w:
                if(val[i-1] + K[i-1][w-wt[i-1]] >  K[i-1][w]):
                    K[i][w] = val[i-1] + K[i-1][w-wt[i-1]]
                    c = N[i-1][w-wt[i-1]]
                    c[i] = i
                    N[i][w] = c
                else:
                    K[i][w] = K[i-1][w]
                    N[i][w] = N[i-1][w]
            else:
                K[i][w] = K[i-1][w]
    N[n][W][0] = K[n][W]
    return N[n][W]

import random
size = 1000
val = [random.randint(1, size) for i in range(0, size)]
wt = [random.randint(1, size) for i in range(0, size)]
W = 1000
n = len(val)

%lprun -f knapSack knapSack(W, wt, val, n)

这里是新函数:

 @njit
    def knapSack(W, wt, val, n):

        K = np.zeros((n + 1, W + 1),dtype=np.int32)
        # In fact we must only save the previous combinations and the current, 
        # not all :) So N is considerably reduce
        N = np.zeros((2, W + 1, W + 1),dtype=np.int32)

        for i in range(n + 1):
            for w in range(W + 1):
                if i == 0 or w == 0:
                    K[i][w] = 0
                elif wt[i - 1] <= w:
                    if val[i - 1] + K[i - 1][w - wt[i - 1]] > K[i - 1][w]:
                        K[i][w] = val[i - 1] + K[i - 1][w - wt[i - 1]]
                        N[i%2][w] = np.copy(N[(i - 1)%2][w - wt[i - 1]])
                        N[i%2][w][i] = i
                    else:
                        K[i][w] = K[i - 1][w]
                        N[i%2][w] = N[(i - 1)%2][w]
                else:
                    K[i][w] = K[i - 1][w]
        N[(n)%2][W][0] = K[n][W]
        return N[(n)%2][W]

非常感谢 MSeifert !!