比较 numba 编译函数中的字符串

Compare strings in numba-compiled function

我正在寻找在使用 numba jit(无 python 模式,python 3)编译的 python 函数中比较字符串的最佳方法。

用例如下:

import numba as nb

@nb.jit(nopython = True, cache = True)
def foo(a, t = 'default'):
    if t == 'awesome':
        return(a**2)
    elif t == 'default':
        return(a**3)
    else:
        ...

但是返回如下错误:

Invalid usage of == with parameters (str, const('awesome'))

我尝试使用字节但没有成功。

谢谢!


Maurice 指出了问题 但我正在查看原生 python 而不是 numba 支持的 numpy 子集。

对于较新的 numba 版本(0.41.0 及更高版本)

Numba(自版本 0.41.0 起)支持 str in nopython mode 并且问题中所写的代码将“正常工作”。但是,对于您的示例,比较字符串 比您的操作慢 很多,因此如果您想在 numba 函数中使用字符串,请确保开销是值得的。

import numba as nb

@nb.njit
def foo_string(a, t):
    if t == 'awesome':
        return(a**2)
    elif t == 'default':
        return(a**3)
    else:
        return a

@nb.njit
def foo_int(a, t):
    if t == 1:
        return(a**2)
    elif t == 0:
        return(a**3)
    else:
        return a

assert foo_string(100, 'default') == foo_int(100, 0)
%timeit foo_string(100, 'default')
# 2.82 µs ± 45.9 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit foo_int(100, 0)
# 213 ns ± 10.2 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

在您的例子中,使用字符串的代码速度要慢 10 倍以上。

因为你的函数做的不多,所以在 Python 而不是 numba:

中进行字符串比较会更好更快
def foo_string2(a, t):
    if t == 'awesome':
        sec = 1
    elif t == 'default':
        sec = 0
    else:
        sec = -1
    return foo_int(a, sec)

assert foo_string2(100, 'default') == foo_string(100, 'default')
%timeit foo_string2(100, 'default')
# 323 ns ± 10.6 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

这仍然比纯整数版本慢一点,但比在 numba 函数中使用字符串快将近 10 倍。

但是,如果您在 numba 函数中进行大量数值运算,则字符串比较开销将无关紧要。但是简单地把 numba.njit 放在一个函数上,尤其是如果它不做很多数组操作或数字运算,不会让它自动变快!

对于旧的 numba 版本(0.41.0 之前):

Numba 在 nopython 模式下不支持字符串。

来自documentation

2.6.2. Built-in types

2.6.2.1. int, bool [...]

2.6.2.2. float, complex [...]

2.6.2.3. tuple [...]

2.6.2.4. list [...]

2.6.2.5. set [...]

2.6.2.7. bytes, bytearray, memoryview

The bytearray type and, on Python 3, the bytes type support indexing, iteration and retrieving the len().

[...]

因此根本不支持字符串,字节也不支持相等性检查。

但是您可以传入 bytes 并迭代它们。这使得编写自己的比较函数成为可能:

import numba as nb

@nb.njit
def bytes_equal(a, b):
    if len(a) != len(b):
        return False
    for char1, char2 in zip(a, b):
        if char1 != char2:
            return False
    return True

不幸的是,下一个问题是 numba 不能“降低”字节,所以你不能直接在函数中硬编码字节。但是字节基本上只是整数,并且 bytes_equal 函数适用于 numba 支持的所有类型,这些类型具有长度并且可以迭代。所以你可以简单地将它们存储为列表:

import numba as nb

@nb.njit
def foo(a, t):
    if bytes_equal(t, [97, 119, 101, 115, 111, 109, 101]):
        return a**2
    elif bytes_equal(t, [100, 101, 102, 97, 117, 108, 116]):
        return a**3
    else:
        return a

或作为全局数组(感谢@chrisb - 查看评论):

import numba as nb
import numpy as np

AWESOME = np.frombuffer(b'awesome', dtype='uint8')
DEFAULT = np.frombuffer(b'default', dtype='uint8')

@nb.njit
def foo(a, t):
    if bytes_equal(t, AWESOME):
        return a**2
    elif bytes_equal(t, DEFAULT):
        return a**3
    else:
        return a

两者都可以正常工作:

>>> foo(10, b'default')
1000
>>> foo(10, b'awesome')
100
>>> foo(10, b'awe')
10

但是,您不能将字节数组指定为默认值,因此您需要明确提供 t 变量。这样做也感觉很老套。

我的意见:只需在普通函数中执行 if t == ... 检查并在 if 中调用专门的 numba 函数。 Python 中的字符串比较非常快,只需将 math/array-intensive 内容包装在 numba 函数中即可:

import numba as nb

@nb.njit
def awesome_func(a):
    return a**2

@nb.njit
def default_func(a):
    return a**3

@nb.njit
def other_func(a):
    return a

def foo(a, t='default'):
    if t == 'awesome':
        return awesome_func(a)
    elif t == 'default':
        return default_func(a)
    else:
        return other_func(a)

但请确保您确实需要 numba 来实现这些功能。有时正常 Python/NumPy 就足够快了。只需分析 numba 解决方案和 Python/NumPy 解决方案,看看 numba 是否使其速度显着加快。 :)

我建议接受@MSeifert 的回答,但作为此类问题的另一种选择,请考虑使用 enum.

在 python 中,字符串通常用作一种枚举,您 numba 内置了对枚举的支持,因此可以直接使用它们。

import enum

class FooOptions(enum.Enum):
    AWESOME = 1
    DEFAULT = 2

import numba

@numba.njit
def foo(a, t=FooOptions.DEFAULT):
    if t == FooOptions.AWESOME:
        return a**2
    elif t == FooOptions.DEFAULT:
        return a**2
    else:
        return a

foo(10, FooOptions.AWESOME)
Out[5]: 100