将 numba 函数拆分为项目中的单独模块以进行打包

splitting numba functions into separate modules in project for packaging

我的项目中有几个模块,每个模块都包含几个 numba 函数。我知道在第一次导入时,函数会被编译。 我现在才注意到的是,即使我只从模块中导入一个函数,似乎所有函数都会被编译,因为导入花费的时间相同。

我想为此实现一种更细粒度的方法,因为对于某些应用程序,您实际上只需要一个函数,因此编译所有函数是浪费时间。

为此,我将函数拆分为单独的模块,如下所示:

Project/
|--src/
|  |-- __init__.py
|  |-- fun1.py
|  |-- fun2.py
|  |-- fun3.py 
|  |-- fun4.py
|  |-- ...

__init__.py包括

from .fun1 import fun1
from .fun2 import fun2
...

所以它们可以像 from src import fun1 一样导入。

这似乎工作正常,但在导入级别有一些重复,例如每个函数都需要 from numba import jit,其中一些需要 from numpy import zeros 等等。

所以我的问题是这样是否可行,或者是否有更好的方法来打包许多 numba 函数。

编辑:

将所有导入语句放入 __init__.py 显然意味着一旦导入一个函数,所有函数都会被编译 - 所以根本没有任何收获。

我仍然可以导入像

这样的功能
from src.fun1 import fun1

这似乎有效。但是语法有点笨拙。

有趣的问题 - 您实际上是在问如何延迟函数的定义,直到它被显式导入。 我认为最好的方法就像你说的那样,使用 from src.fun1 import fun1 并且每个文件有一个函数。

我认为当你在同一个文件中有多个函数时实现这个可能非常棘手,所以我将问题放松为“我们如何延迟函数的定义,直到它被显式 调用(未导入)"。

简单的解决方案

一个简单的方法是将您的函数包装在一个虚拟外部函数中。

这并不完全符合我们的要求,因为后续调用 fun1 将导致重新创建内部函数和 numba.jit 装饰器,并且需要重新编译。

# main.py

# This lets us see when numba is compiling.
# See https://numba.pydata.org/numba-doc/dev/reference/envvars.html
import os
os.environ["NUMBA_DEBUG_FRONTEND"] = "1"

import fun1
print("note no numba debug output yet for fun1")
print("fun1 result is", fun1.fun1(1, 2))
print("fun1 result is", fun1.fun1(2, 1))
print("note the function was compiled twice :(")

# fun1.py

import numba

# Naively wrap fun1 in another function so it's only declared
# when the outer function is called.
def fun1(*args, **kwargs):
    @numba.jit("float32(float32, float32)", cache=False)  # No cache, for debugging
    def __fun1(a, b):
        return a + b
    return __fun1(*args, **kwargs)

使用装饰器的更高级解决方案

简单的解决方案是将您的函数包装在另一个函数中....闻起来很像装饰器....

我创建了一个装饰器(外部装饰器),它将另一个装饰器(内部装饰器)作为输入。 外部装饰器仅在第一次调用该函数时应用内部装饰器(在本例中为 numba.jit)。然后它在后续调用中重新使用内部装饰函数。

# main.py

# This lets us see when numba is compiling.
# See https://numba.pydata.org/numba-doc/dev/reference/envvars.html
import os
os.environ["NUMBA_DEBUG_FRONTEND"] = "1"

import fun2
print("note no numba debug output yet for fun2")
print("fun2 result is", fun2.fun2(3, 4))
print("fun2 result is", fun2.fun2(5, 6))
print("note the function was compiled only once :)")

# fun2.py

import numba
from functools import wraps

def delayed(internal_decorator):
    def _delayed(f):
        inner_decorated = None
        @wraps(f)
        def wrapper(*args, **kwds):
            nonlocal inner_decorated
            if inner_decorated is None:
                inner_decorated = internal_decorator(f)
            return inner_decorated(*args, **kwds)
        return wrapper
    return _delayed

@delayed(numba.jit("float32(float32, float32)", cache=False))
def fun2(a, b):
    return a * b

I know that on first import the functions get compiled.

这不一定是真的。 Eager compilation 仅当使用显式签名声明函数时才会发生。

声明带有显式签名的函数有两个后果:

  • 函数由装饰器编译(在声明或导入时)。
  • 不允许其他函数特化。

如果你不需要禁止额外的特化,你可以简单地删除函数签名,它们将在每个特化的第一次执行时被编译。

您可以通过在 numba.core.dispatcher.Dispatcher.compile() 方法中放置一个断点来检查函数何时被编译。