Is there a faster method than np.isin and np.where for large arrays?

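I have a 1xN array A and a 2xM array B. I want to make two new 1xN arrays: a boolean one marking whether each element of A appears in the first row of B, and one holding the matching second-row value of B (NaN where there is no match), as in the code below.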

Whatever method I use needs to be very fast, because it will be called many times. I can already do the first part using this post: Is there method faster than np.isin for large array?

But I am struggling to find a good way to do the second part. This is what I have so far (adapting the code from the post above):

import numpy as np
import numba as nb

@nb.jit(parallel=True)
def isinvals(arr, vals):
    n = len(arr)
    result = np.full(n, False)
    result_vals = np.full(n, np.nan)
    set_vals = set(vals[0,:])
    list_vals = list(vals[0,:])
    for i in nb.prange(n):
        if arr[i] in set_vals:
            ind = list_vals.index(arr[i]) ## THIS LINE IS WAY TOO SLOW
            result[i] = True
            result_vals[i] = vals[1,ind]
    return result, result_vals


N = int(1e5)
M = int(20e3)
num_arr = 100e3
num_vals = 20e3
num_types = 6
arr = np.random.randint(0, num_arr, N)
vals_col1 = np.random.randint(0, num_vals, M)
vals_col2 = np.random.randint(0, num_types, M)
vals = np.array([vals_col1, vals_col2])

%timeit result, result_vals = isinvals(arr,vals)
46.4 ms ± 3.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

The line I marked above (list_vals.index(arr[i])) is the slow part. Without it, I can make a super-fast version:

@nb.jit(parallel=True)
def isinvals_cheating(arr, vals):
    n = len(arr)
    result = np.full(n, False)
    result_vals = np.full(n, np.nan)
    set_vals = set(vals[0,:])
    list_vals = list(vals[0,:])
    for i in nb.prange(n):
        if arr[i] in set_vals:
            ind = 0 ## TEMPORARILY SETTING TO 0 TO INDICATE SPEED DIFFERENCE
            result[i] = True
            result_vals[i] = vals[1,ind]
    return result, result_vals

%timeit result, result_vals = isinvals_cheating(arr,vals)
1.13 ms ± 59.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In other words, that one line makes it about 40 times slower.

Any ideas? I have also tried np.where(), but it is even slower.

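This assumes the OP's solution gives the expected result, since the question seems ambiguous when non-unique values in vals[0, idx] have different corresponding values vals[1, idx]. A lookup table is faster, but needs len(arr) additional space. Note that the table is indexed directly by the values, so they must be non-negative integers smaller than len(arr), as they are in the sample data.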

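@nb.njit  # tested with numba 0.55.1
def isin_nb(arr, vals):
    # Lookup table indexed by value; NaN marks "not present in vals[0]".
    lookup = np.empty(len(arr), np.float32)
    lookup.fill(np.nan)
    # Assigned in reverse order; the intent is that, for duplicate keys,
    # the first occurrence is written last, matching list.index().
    lookup[vals[0, ::-1]] = vals[1, ::-1]
    res_val = lookup[arr]
    return ~np.isnan(res_val), res_val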
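If the values are not guaranteed to be smaller than len(arr), the same lookup-table idea can be sketched in plain NumPy with the table sized from the largest value instead. This is only an illustration, not the answer's tested code: the name isin_lookup and the small arrays are hypothetical, and it assumes non-empty inputs holding non-negative integers. It uses np.unique(..., return_index=True) to keep the first occurrence of each duplicate key rather than relying on assignment order.

```python
import numpy as np

def isin_lookup(arr, vals):
    """Hypothetical NumPy-only variant of the lookup-table idea.

    Assumes arr and vals[0] are non-empty and hold non-negative integers.
    """
    # Sorted unique keys plus the index of each key's first occurrence,
    # which mirrors what list.index() would return.
    keys, first = np.unique(vals[0], return_index=True)
    # Size the table from the largest value that will ever index it.
    size = int(max(arr.max(), keys.max())) + 1
    lookup = np.full(size, np.nan)
    lookup[keys] = vals[1, first]
    res_val = lookup[arr]
    return ~np.isnan(res_val), res_val

# Tiny made-up example: the key 3 appears twice (values 10 and 30).
arr_small = np.array([3, 0, 5, 3, 7])
vals_small = np.array([[3, 5, 3],
                       [10, 20, 30]])
found, mapped = isin_lookup(arr_small, vals_small)
```

Here the duplicate key 3 resolves to 10, its first occurrence, and the entries 0 and 7 that are absent from vals_small[0] come back as NaN.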

Using the sample data from the question:

res, res_val = isin_nb(arr, vals)
# %timeit 1000 loops, best of 5: 294 µs per loop

Asserting that the results are equal:

np.testing.assert_equal(res, result)
np.testing.assert_equal(res_val, result_vals)
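When the values are large, sparse, or negative, a dense table may not be practical. A different technique, sorted search with np.searchsorted, avoids the table at the cost of a binary search per element. The sketch below is an assumption-laden illustration and has not been timed against the versions above; isin_sorted and the sample arrays are hypothetical names, and vals is assumed to have at least one column.

```python
import numpy as np

def isin_sorted(arr, vals):
    """Hypothetical sorted-search alternative to the dense lookup table."""
    # Sorted unique keys and the index of each key's first occurrence.
    keys, first = np.unique(vals[0], return_index=True)
    mapped = vals[1, first].astype(np.float64)
    # Position where each arr element would be inserted into keys.
    pos = np.searchsorted(keys, arr)
    # Clip so elements larger than every key still index safely.
    pos = np.minimum(pos, len(keys) - 1)
    # A match exists only where the key at that position equals the element.
    found = keys[pos] == arr
    res_val = np.where(found, mapped[pos], np.nan)
    return found, res_val

# Made-up data with a negative value and a very large key.
arr_wide = np.array([-4, 10**12, 7, 3], dtype=np.int64)
vals_wide = np.array([[10**12, 3, 3],
                      [1, 2, 9]], dtype=np.int64)
found_wide, mapped_wide = isin_sorted(arr_wide, vals_wide)
```

Memory use here scales with M rather than with the largest value, which is the trade-off against the lookup table's direct indexing.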