Numba:用数组中的值替换键的快速方法
Numba: Fast way of replacing keys with values in an array
我想用具有重复元素的大尺寸 array
将 keys
替换为 values
。我正在尝试 numba
和 numpy
映射方法。两种方式的代码如下
import numpy as np
from numba import njit, prange
array1 = np.arange(150*150*150, dtype=int)
array2 = np.arange(150*150*150, dtype=int)
array = np.concatenate((array1, array2))
keys = np.arange(50)
values = -1 * np.arange(50)
## Numba Approach
@njit(parallel=True)
def numba_replace(array, keys, values):
for i in prange(len(keys)):
for j in prange(len(array)):
if array[j] == keys[i]:
array[j] = values[i]
## numpy approach
def numpy_replace(array, keys, values):
mapp = np.arange(array.size)
mapp[keys] = values
mapped = mapp[array]
return mapped
## Performance
%%timeit
numba_replace(array, keys, values)
# 117 ms ± 969 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%%timeit
numpy_replace(array, keys, values)
# 61.2 ms ± 159 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
虽然 numpy_replace
比 numba_replace
快 2 倍,但我不喜欢使用它,因为我的数组大小非常大 (3000 x 3000 x 3000)
并且 numpy 方法创建了一个 new array
增加内存使用。有什么方法可以使 numba_replace 更快,或者有什么方法不会在处理过程中创建新数组吗?
我猜是:
array[keys] = values
在 numpy 中完成工作,而不创建任何新数组
编辑: 只是为了检查该命令是否执行与您的 numpy_replace
函数相同的操作:
mapped = numpy_replace(array, keys, values)
array[keys] = values
print(all(mapped == array)) # --> True
改进 Numba 方法(降低复杂性)
由于您只想更改相对少量的值,您可以使用集合来确定是否必须更改实际的数组元素。
此外,您可以使用 search_sorted 来获取正确的键值对。对于这个小例子,差异不是很大,但如果问题规模增加,差异会变得更大。
实施
import numpy as np
from numba import njit, prange
@njit(parallel=True)
def numba_replace(array, keys, values):
ind_sort=np.argsort(keys)
keys_sorted=keys[ind_sort]
values_sorted=values[ind_sort]
s_keys=set(keys)
for j in prange(array.shape[0]):
if array[j] in s_keys:
ind = np.searchsorted(keys_sorted,array[j])
array[j]=values_sorted[ind]
return array
时间
import numpy as np
from numba import njit, prange
array1 = np.arange(150*150*150, dtype=int)
array2 = np.arange(150*150*150, dtype=int)
array = np.concatenate((array1, array2))
#to get proper timings do nothing here
#changing the array in-place will obviously have
#an influence on the timings, because there are no values to change in the second run
keys = np.arange(50)
values = np.arange(50)
%timeit numba_replace(array, keys, values)
# 20.1 ms ± 1.95 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit numpy_replace(array, keys, values)
# 51.3 ms ± 392 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
我想用具有重复元素的大尺寸 array
将 keys
替换为 values
。我正在尝试 numba
和 numpy
映射方法。两种方式的代码如下
import numpy as np
from numba import njit, prange
array1 = np.arange(150*150*150, dtype=int)
array2 = np.arange(150*150*150, dtype=int)
array = np.concatenate((array1, array2))
keys = np.arange(50)
values = -1 * np.arange(50)
## Numba Approach
@njit(parallel=True)
def numba_replace(array, keys, values):
for i in prange(len(keys)):
for j in prange(len(array)):
if array[j] == keys[i]:
array[j] = values[i]
## numpy approach
def numpy_replace(array, keys, values):
mapp = np.arange(array.size)
mapp[keys] = values
mapped = mapp[array]
return mapped
## Performance
%%timeit
numba_replace(array, keys, values)
# 117 ms ± 969 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%%timeit
numpy_replace(array, keys, values)
# 61.2 ms ± 159 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
虽然 numpy_replace
比 numba_replace
快 2 倍,但我不喜欢使用它,因为我的数组大小非常大 (3000 x 3000 x 3000)
并且 numpy 方法创建了一个 new array
增加内存使用。有什么方法可以使 numba_replace 更快,或者有什么方法不会在处理过程中创建新数组吗?
我猜是:
array[keys] = values
在 numpy 中完成工作,而不创建任何新数组
编辑: 只是为了检查该命令是否执行与您的 numpy_replace
函数相同的操作:
mapped = numpy_replace(array, keys, values)
array[keys] = values
print(all(mapped == array)) # --> True
改进 Numba 方法(降低复杂性)
由于您只想更改相对少量的值,您可以使用集合来确定是否必须更改实际的数组元素。 此外,您可以使用 search_sorted 来获取正确的键值对。对于这个小例子,差异不是很大,但如果问题规模增加,差异会变得更大。
实施
import numpy as np
from numba import njit, prange
@njit(parallel=True)
def numba_replace(array, keys, values):
ind_sort=np.argsort(keys)
keys_sorted=keys[ind_sort]
values_sorted=values[ind_sort]
s_keys=set(keys)
for j in prange(array.shape[0]):
if array[j] in s_keys:
ind = np.searchsorted(keys_sorted,array[j])
array[j]=values_sorted[ind]
return array
时间
import numpy as np
from numba import njit, prange
array1 = np.arange(150*150*150, dtype=int)
array2 = np.arange(150*150*150, dtype=int)
array = np.concatenate((array1, array2))
#to get proper timings do nothing here
#changing the array in-place will obviously have
#an influence on the timings, because there are no values to change in the second run
keys = np.arange(50)
values = np.arange(50)
%timeit numba_replace(array, keys, values)
# 20.1 ms ± 1.95 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit numpy_replace(array, keys, values)
# 51.3 ms ± 392 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)