对于大型数组,是否有比 np.isin 和 np.where 更快的方法?
Is there a faster method than np.isin and np.where for large arrays?
我有一个 1xN 数组 A 和一个 2xM 数组 B。我想制作两个新的 1xN 数组
- 一个布尔值,用于检查 B 的第一列是否在 A 中
- 另一个条目 i 是 B[1,i] 如果 B[0,i] 在 A 中,并且 np.nan 否则
无论我使用什么方法,都需要非常快,因为它会被调用很多次。我可以用这个做第一部分:Is there method faster than np.isin for large array?
但我很难找到完成第二部分的好方法。到目前为止,这是我得到的(调整上面 post 中的代码):
import numpy as np
import numba as nb
@nb.jit(parallel=True)
def isinvals(arr, vals):
n = len(arr)
result = np.full(n, False)
result_vals = np.full(n, np.nan)
set_vals = set(vals[0,:])
list_vals = list(vals[0,:])
for i in nb.prange(n):
if arr[i] in set_vals:
ind = list_vals.index(arr[i]) ## THIS LINE IS WAY TOO SLOW
result[i] = True
result_vals[i] = vals[1,ind]
return result, result_vals
N = int(1e5)
M = int(20e3)
num_arr = 100e3
num_vals = 20e3
num_types = 6
arr = np.random.randint(0, num_arr, N)
vals_col1 = np.random.randint(0, num_vals, M)
vals_col2 = np.random.randint(0, num_types, M)
vals = np.array([vals_col1, vals_col2])
%timeit result, result_vals = isinvals(arr,vals)
46.4 ms ± 3.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
我在上面标记的行 (list_vals.index(arr[i])
) 是缓慢的部分。如果我不使用它,我可以制作一个超快的版本:
@nb.jit(parallel=True)
def isinvals_cheating(arr, vals):
n = len(arr)
result = np.full(n, False)
result_vals = np.full(n, np.nan)
set_vals = set(vals[0,:])
list_vals = list(vals[0,:])
for i in nb.prange(n):
if arr[i] in set_vals:
ind = 0 ## TEMPORARILY SETTING TO 0 TO INDICATE SPEED DIFFERENCE
result[i] = True
result_vals[i] = vals[1,ind]
return result, result_vals
%timeit result, result_vals = isinvals_cheating(arr,vals)
1.13 ms ± 59.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
即那条线让它慢了 40 倍。
有什么想法吗?我也试过使用 np.where() 但它更慢。
假设 OP 的解决方案给出了预期的结果,因为该问题对于 vals[0, idx]
中的 non-unique 值具有不同的对应值 vals[1, idx]
似乎不明确。查找 table 更快,但需要 len(arr)
额外的 space.
@nb.njit # tested with numba 0.55.1
def isin_nb(arr, vals):
lookup = np.empty(len(arr), np.float32)
lookup.fill(np.nan)
lookup[vals[0, ::-1]] = vals[1, ::-1]
res_val = lookup[arr]
return ~np.isnan(res_val), res_val
配合问题中使用的示例数据
res, res_val = isin_nb(arr, vals)
# %timeit 1000 loops, best of 5: 294 µs per loop
断言结果相等
np.testing.assert_equal(res, result)
np.testing.assert_equal(res_val, result_vals)
我有一个 1xN 数组 A 和一个 2xM 数组 B。我想制作两个新的 1xN 数组
- 一个布尔值,用于检查 B 的第一列是否在 A 中
- 另一个条目 i 是 B[1,i] 如果 B[0,i] 在 A 中,并且 np.nan 否则
无论我使用什么方法,都需要非常快,因为它会被调用很多次。我可以用这个做第一部分:Is there method faster than np.isin for large array?
但我很难找到完成第二部分的好方法。到目前为止,这是我得到的(调整上面 post 中的代码):
import numpy as np
import numba as nb
@nb.jit(parallel=True)
def isinvals(arr, vals):
n = len(arr)
result = np.full(n, False)
result_vals = np.full(n, np.nan)
set_vals = set(vals[0,:])
list_vals = list(vals[0,:])
for i in nb.prange(n):
if arr[i] in set_vals:
ind = list_vals.index(arr[i]) ## THIS LINE IS WAY TOO SLOW
result[i] = True
result_vals[i] = vals[1,ind]
return result, result_vals
N = int(1e5)
M = int(20e3)
num_arr = 100e3
num_vals = 20e3
num_types = 6
arr = np.random.randint(0, num_arr, N)
vals_col1 = np.random.randint(0, num_vals, M)
vals_col2 = np.random.randint(0, num_types, M)
vals = np.array([vals_col1, vals_col2])
%timeit result, result_vals = isinvals(arr,vals)
46.4 ms ± 3.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
我在上面标记的行 (list_vals.index(arr[i])
) 是缓慢的部分。如果我不使用它,我可以制作一个超快的版本:
@nb.jit(parallel=True)
def isinvals_cheating(arr, vals):
n = len(arr)
result = np.full(n, False)
result_vals = np.full(n, np.nan)
set_vals = set(vals[0,:])
list_vals = list(vals[0,:])
for i in nb.prange(n):
if arr[i] in set_vals:
ind = 0 ## TEMPORARILY SETTING TO 0 TO INDICATE SPEED DIFFERENCE
result[i] = True
result_vals[i] = vals[1,ind]
return result, result_vals
%timeit result, result_vals = isinvals_cheating(arr,vals)
1.13 ms ± 59.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
即那条线让它慢了 40 倍。
有什么想法吗?我也试过使用 np.where() 但它更慢。
假设 OP 的解决方案给出了预期的结果,因为该问题对于 vals[0, idx]
中的 non-unique 值具有不同的对应值 vals[1, idx]
似乎不明确。查找 table 更快,但需要 len(arr)
额外的 space.
@nb.njit # tested with numba 0.55.1
def isin_nb(arr, vals):
lookup = np.empty(len(arr), np.float32)
lookup.fill(np.nan)
lookup[vals[0, ::-1]] = vals[1, ::-1]
res_val = lookup[arr]
return ~np.isnan(res_val), res_val
配合问题中使用的示例数据
res, res_val = isin_nb(arr, vals)
# %timeit 1000 loops, best of 5: 294 µs per loop
断言结果相等
np.testing.assert_equal(res, result)
np.testing.assert_equal(res_val, result_vals)