更快的 numpy isin 替代使用 numba 的字符串

Faster numpy isin alternative for strings using numba

我正在尝试在 numba 中实现 np.isin 的更快版本,这是我目前所拥有的:

import numpy as np
import numba as nb

@nb.njit(parallel=True)
def isin(a, b):
    out=np.empty(a.shape[0], dtype=nb.boolean)
    b = set(b)
    for i in nb.prange(a.shape[0]):
        if a[i] in b:
            out[i]=True
        else:
            out[i]=False
    return out

对于数字它有效,如本例所示:

a = np.array([1,2,3,4])
b = np.array([2,4])

isin(a,b)
>>> array([False,  True, False,  True])

而且比np.isin快:

a = np.random.rand(20000)
b = np.random.rand(5000)

%time isin(a,b)
CPU times: user 3.96 ms, sys: 0 ns, total: 3.96 ms
Wall time: 1.05 ms

%time np.isin(a,b)
CPU times: user 11 ms, sys: 0 ns, total: 11 ms
Wall time: 8.48 ms

但是,我想将此函数用于包含字符串的数组。问题是,每当我尝试传递一个字符串数组时,numba 抱怨它无法用这些数据解释 set() 操作。

a = np.array(['A','B','C','D'])
b = np.array(['B','D'])

isin(a,b)
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<class 'set'>) found for signature:
 
 >>> set(array([unichr x 1], 1d, C))
 
There are 2 candidate implementations:
  - Of which 2 did not match due to:
  Overload of function 'set': File: numba/core/typing/setdecl.py: Line 20.
    With argument(s): '(array([unichr x 1], 1d, C))':
   No match.

During: resolving callee type: Function(<class 'set'>)
During: typing of call at /tmp/ipykernel_20582/4221597969.py (7)


File "../../../../tmp/ipykernel_20582/4221597969.py", line 7:
<source missing, REPL/exec in use?>

有没有办法,比如指定签名,让我可以直接在字符串数组上使用它?

我知道我可以为每个字符串分配一个数值,但对于大型数组,我认为这需要一段时间,并且会使整个过程比仅使用 np.isin.

有什么想法吗?

Numba 几乎不支持字符串(类似于 bytes,尽管支持稍微好一点)。 Set 和 dictionary 支持有一些严格的限制,并且相当 experimental/new。 尚不支持字符串集 关于 documentation:

Sets must be strictly homogeneous: Numba will reject any set containing objects of different types, even if the types are compatible (for example, {1, 2.5} is rejected as it contains a int and a float). The use of reference counted types, e.g. strings, in sets is unsupported.

您可以尝试使用二进制搜索作弊。不幸的是,np.searchsorted 还没有为 string-typed Numpy 数组实现(尽管 np.unique 是)。我认为您可以自己实施二分查找,但这最终会很麻烦。我不确定这最终会更快,但我认为应该是因为 O(Ns Na log Nb)) 运行 时间复杂度(Ns 字符串长度的平均大小为 b 独特物品,Na a 中的物品数量和 Nb b 中独特物品的数量)。事实上,np.isin 的 运行 时间复杂度是 O(Ns (Na+Nb) log (Na+Nb)) 如果数组大小相似,如果 NbNa 小得多,则 O(Ns Na Nb)。请注意,最好的理论 运行 时间复杂度是 AFAIK O(Ns (Na + Nb)),这要归功于具有良好哈希函数的哈希 table(Tries 也可以实现这一点,但它们实际上应该更慢,除非散列函数不是很好)。

请注意,类型化词典支持静态声明的字符串但不支持动态字符串(这是静态字符串的实验性功能)。

另一个作弊(应该有效)是 将字符串散列存储为类型字典的键 并将每个散列关联到引用 [= 中字符串位置的索引数组17=] 具有关联键的哈希值。并行循环需要散列 a 项并获取 b 中具有此散列的字符串项的位置,以便您随后可以比较字符串。更快的实现是假设 b 字符串的哈希函数是完美的并且没有冲突,因此您可以直接使用 TypedDict[np.int64, np.int64] 哈希 table。您可以在构建 b 时在运行时测试这个假设。写这样的代码有点乏味。请注意,此实现最终可能不会比 Numpy 快,因为 Numba TypedDicts 目前非常慢......但是,在具有足够内核的处理器上,这应该更快。