How to force looplifting in Numba?

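I am trying to implement a Numba-compiled version of numpy.take(), but I am quite confused by Numba.

First, as far as I know, it is not possible to create a new ndarray inside a Numba function in nopython mode. The current documentation does not seem to mention this anywhere, but I was able to find it in the old Numba v0.15 docs. So I had to give up on nopython mode.

Second, I do not understand how looplifting works, at least not in my tests. Here is my code: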

from numba import jit
import numpy as np
import time

@jit(forceobj = True)
def _take1(arr, idxs):
    res = np.ndarray((idxs.size,), arr.dtype)
    lastIdx = arr.size - 1
    for i in range(idxs.size):
        idx = idxs[i]
        if idx > lastIdx:
            idx = lastIdx
        elif idx < 0:
            idx = 0
        res[i] = arr[idx]
    return res

def _take2(arr, idxs):
    res = np.ndarray((idxs.size,), arr.dtype)
    lastIdx = arr.size - 1
    for i in range(idxs.size):
        idx = idxs[i]
        if idx > lastIdx:
            idx = lastIdx
        elif idx < 0:
            idx = 0
        res[i] = arr[idx]
    return res

sz = 2000000
arr = np.arange(sz, dtype = np.int32)
idxs = np.arange(sz, 0, -1, dtype = np.int32)

start = time.time()
_take2(arr, idxs)
end = time.time()
print("Elapsed (plain python)      = %s" % (end - start))

start = time.time()
_take1(arr, idxs)
end = time.time()
print("Elapsed (with compilation)  = %s" % (end - start))

start = time.time()
_take1(arr, idxs)
end = time.time()
print("Elapsed (after compilation) = %s" % (end - start))

This code gives the following output on my laptop:

Elapsed (plain python)      = 0.8870017528533936
Elapsed (with compilation)  = 1.7350387573242188
Elapsed (after compilation) = 1.1779978275299072

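So the compiled version is slower than plain Python. My guess is that this is because it does not use looplifting, and I could not find a way to enable it. On the other hand, the 5-minute guide says this about object mode: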

In this mode Numba will identify loops that it can compile and compile those into functions that run in machine code, and it will run the rest of the code in the interpreter

But there is a catch. If I remove (forceobj = True) from line 5 (leaving just @jit), I get the following output:

Elapsed (plain python)      = 0.88104248046875
c:\[path-to-my-folder]\test.py:6: NumbaWarning: 
Compilation is falling back to object mode WITH looplifting enabled because Function "_take1" failed type inference due to: Use of unsupported NumPy function 'numpy.ndarray' or unsupported use of the function.

File "test.py", line 8:
def _take1(arr, idxs):
    res = np.ndarray((idxs.size,), arr.dtype)
    ^

During: typing of get attribute at c:\[path-to-my-folder]\test.py (8)

File "test.py", line 8:
def _take1(arr, idxs):
    res = np.ndarray((idxs.size,), arr.dtype)
    ^

  @jit

  warnings.warn(errors.NumbaWarning(warn_msg,
C:\[path-to-my-python]\python38\lib\site-packages\numba\core\object_mode_passes.py:161: NumbaDeprecationWarning:
Fall-back from the nopython compilation path to the object mode compilation path has been detected, this is deprecated behaviour.

For more information visit https://numba.pydata.org/numba-doc/latest/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit

File "test.py", line 8:
def _take1(arr, idxs):
    res = np.ndarray((idxs.size,), arr.dtype)
    ^
  warnings.warn(errors.NumbaDeprecationWarning(msg,
Elapsed (with compilation)  = 0.4593193531036377
Elapsed (after compilation) = 0.0029413700103759766

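So now the compiled version of the function is much faster than plain Python, even though Numba is apparently still in object mode. What looks suspicious to me is the phrase "object mode WITH looplifting enabled" in the warning text. It suggests that there can also be an object mode without looplifting.

To sum up, I have 3 questions:

  1. Is it true that I can not create a new ndarray with Numba in nopython mode?
  2. Why does @jit with fallback to object mode work a lot faster than @jit(forceobj = True)?
  3. Is there an object mode without looplifting in Numba? And if so, what is its purpose?

Thanks!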

Is it true that I can not create a new ndarray with Numba in nopython mode?

This is false. The following working code does exactly that:

import numpy as np
import numba as nb

@nb.njit('int32[:](int32)')
def example(n):
    res = np.empty(n, dtype=np.int32)
    for i in range(n):
        res[i] = i
    return res

example(4)  # output: [0, 1, 2, 3]

Why does @jit with fallback to object mode works a lot faster than @jit(forceobj = True)?

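You should avoid object mode like the plague. It is generally very inefficient and rarely useful, unless you have to deal with Python objects, which are slow because of the GIL, reference counting, allocations, dynamic typing and the many other performance issues that make CPython slow. The documentation states: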

Whilst the use of looplifting in object mode can enable some performance increase, getting functions to compile under no python mode is really the key to good performance.

Additionally, it also states:

forceobj forces the function to be compiled in object mode. Since object mode is slower than nopython mode, this is mostly useful for testing purposes.

So the challenge is to make your code work in njit mode. In your case this is simple: you just need to use np.empty and njit.

from numba import njit

@njit
def _take1(arr, idxs):
    res = np.empty((idxs.size,), arr.dtype)
    lastIdx = arr.size - 1
    for i in range(idxs.size):
        idx = idxs[i]
        if idx > lastIdx:
            idx = lastIdx
        elif idx < 0:
            idx = 0
        res[i] = arr[idx]
    return res

This makes the code much more efficient on my machine:

Elapsed (plain python)      = 3.343827962875366
Elapsed (with compilation)  = 0.06859207153320312
Elapsed (after compilation) = 0.00159454345703125

Note that you can provide a signature for _take1 so that Numba can compile it eagerly.

Is there an object mode without looplifting in Numba? And if so, what is its purpose?

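The purpose of object mode is simply compatibility with pure-Python code. This is useful when Numba code is mixed with pure-Python functions. Calls to pure-Python functions will be slow, but at least Numba can call them from the middle of Numba code. Object mode is generally not very useful in optimized Numba code, and if you do use it, you usually want looplifting so that the hot loops run faster. Otherwise, using Numba in object mode is certainly not very useful (rare exceptions include wrapping code).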