Numba jitted len() 比纯 Python len() 慢

Numba jitted len() is slower than pure Python len()

我正在学习 numba 并遇到了这种我不理解的 "strange" 行为。 我尝试使用以下代码(在 iPython 中进行计时):

import numpy as np
import numba as nb

@nb.njit
def nb_len(seq):
    return len(seq)

def py_len(seq):
    return len(seq)

##
t = np.random.rand(1000)

%timeit nb_len(t)
%timeit py_len(t)

结果如下(实际上是第二个运行由于编译了numba):

258 ns ± 1.37 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
137 ns ± 0.964 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)

纯python版本比numba版本快一倍。 我也试过签名 @nb.njit( nb.int32(nb.float64[:]) ) 但结果还是一样。

我是不是哪里弄错了?

谢谢。

增加时间的不是 len() 部分。使用输入参数调用 jit 函数会增加开销,这就是您看到的时差。

import numba as nb

def py_pass(i):
    return i

@nb.njit()
def nb_pass(i):
    return i

%timeit py_pass(1)
%timeit nb_pass(1)

带输入参数的结果

102 ns ± 0.371 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
165 ns ± 0.783 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)

有趣的是,如果你不需要向 jit 函数传递任何东西,它会更快:

def py_pass():
    return 1

@nb.njit()
def nb_pass():
    return 1

%timeit py_pass()
%timeit nb_pass()

没有输入参数的结果

96.6 ns ± 0.278 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
75.8 ns ± 0.221 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)

正如 所述,在这种情况下,这不是因为 len 函数,而是因为对 numba 函数的调用实际上比对普通 Python 函数的调用慢.

是什么让 jit-ted 函数与众不同?

要理解为什么调用 numba jitted 函数速度较慢,必须了解 numba jitted 函数不再是函数。这是一个调度程序对象:

import numba as nb
@nb.njit
def nb_len(seq):
    return len(seq)
print(nb_len)  # CPUDispatcher(<function nb_len at 0x0000027EB1B4E798>)

这个CPUDispatcher实例表示(可能)基于修饰函数生成的多个编译函数。

这意味着当您调用 CPUDispatcher 实例时有多个步骤:

  • 获取参数的类型。
  • 如果没有适合这些类型参数的编译函数,请使用参数类型编译装饰函数。
  • 有时:将参数转换为相应的 numba 类型。
  • 调用编译后的函数。

与非修饰函数相比,所有这些步骤都会增加开销。特别是如果没有合适的编译函数并且调度程序需要编译函数 - 或者 - 输入类型需要转换(仅发生在 Python 类型,如:列表,集合,字典)调用 CPUDispatcher 将是慢很多——在编写 numba 0.46 时不推荐使用这些类型,部分原因是,请参阅 "2.11.2. Deprecation of reflection for List and Set types".

你的情况

在您的情况下,由于编译,第一次调用 jitted 函数会慢很多。

任何后续调用只会稍微慢一些,因为 numba 必须获取参数类型,检查是否已经存在编译函数,然后调用该编译函数。有趣的是,额外的时间取决于参数的数量和该函数的已编译 "overloads" 的数量。通常这个额外的时间是微不足道的,因为该函数所做的不仅仅是调用 len.

编译时间

尽管函数非常简单,但第一次调用时的编译会花费大量时间:

import numpy as np
import numba as nb

def first_call(seq):
    @nb.njit
    def nb_len(seq):
        return len(seq)
    return nb_len(seq)

@nb.njit
def _nb_len(seq):
    return len(seq)

def subsequent_calls(seq):
    return _nb_len(seq)

t = np.random.rand(1000)
_nb_len(np.ones(1, dtype=np.float64))

%timeit first_call(t)
# 29.8 ms ± 1.57 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit subsequent_calls(t)
# 384 ns ± 6.02 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

转换时间

此外,如果 numba 需要转换参数,它会慢很多。这仅适用于 numba 无法直接处理的 Python 类型,例如列表:

import numpy as np
import numba as nb

@nb.njit
def nb_len(seq):
    return len(seq)

arr = np.random.rand(10_000)
lst = arr.tolist()

nb_len(arr)
nb_len(lst)

%timeit nb_len(arr)
# 354 ns ± 24 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%timeit nb_len(lst)
# 14.1 ms ± 950 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

总结

  • Numba 函数与普通 Python 函数相比有一些额外的开销。所以确保你做 "enough" numba 擅长优化的东西,否则一个普通的 Python 函数会更快、更灵活并且更容易调试。
  • numba 函数中的函数调用与 numba 函数外的函数调用确实不同。所以 nb_len 中的 len()py_len 中的 len() 可以有完全不同的 运行 次。然而,在这种情况下,运行 时间几乎相同。但总的来说,意识到这一点是件好事。
  • 根据参数类型,numba 函数可能(在幕后)非常慢,尤其是在处理 Python 类型作为参数或 return 类型时!