在 numba 中删除 numpy.array 中的一行
Delete a row in numpy.array in numba
这是我第一次post来这里。我正在尝试删除 numba jitclass 中 numpy 数组中的一行。我写了下面的代码来删除任何包含 3 的行:
>>> a = np.array([[1,2,3,4],[5,6,7,8]])
>>> a
>>> array([[1, 2, 3, 4],
[5, 6, 7, 8]])
>>> i = np.where(a==3)
>>> i
>>> (array([0]), array([2]))
我无法使用 numpy.delete() 函数,因为它不受 numba 支持,并且无法将 None 类型的值分配给该行。我所能做的就是通过以下方式将 0 分配给该行:
>>> a[i[0]] = 0
>>> a
>>> array([[0, 0, 0, 0],
[5, 6, 7, 8]])
但我想完全删除该行。
任何帮助将不胜感激。
非常感谢。
我认为您需要再次将您的删除分配给变量 a,这对我有用。试试下面的代码:
import numpy as np
a = np.array([[1,2,3,4],[5,6,7,8]])
print(a)
i = np.where(a==3)
a=np.delete(a, i, 0) # assign it back to the variable
print(a)
欢迎来到 Stacoverflow。您可以简单地使用数组切片来 select 仅其中没有 3 的行。
下面的代码有点复杂,基本上涵盖了额外的细节,尽管你可以有一个更短的版本,删除不必要的行。键值分配是rows_final = [x for x in range(a.shape[0]) if x not in rows3]
代码:
import numpy as np
a = np.array([[1,2,3,4],[5,6,7,8],[10,11,3,13]])
ind = np.argwhere(a==3)
rows3 = ind[0]
cols3 = ind[1]
print ("Initial Array: \n", a)
print()
print("rows, cols of a==3 : ", rows3, cols3)
rows_final = [x for x in range(a.shape[0]) if x not in rows3]
a_final = a[rows_final,:]
print()
print ("Final Rows: \n", rows_final)
print ("Final Array: \n", a_final)
输出:
Initial Array:
[[ 1 2 3 4]
[ 5 6 7 8]
[10 11 3 13]]
rows, cols of a==3 : [0 2] [2 2]
Final Rows:
[1]
Final Array:
[[5 6 7 8]]
这实际上不是一件容易的事,因为 numba 有以下限制:
- 不支持
np.delete
- 不支持
np.all
和 np.any
中的 axis
关键字
- 不支持二维数组索引(至少不支持 bool 掩码)
- 没有或阻碍直接创建具有
np.zeros(shape, dtype=np.bool)
或类似功能的 bool 掩码
但是您仍然可以采用多种方法来解决您的问题。我测试了一些,创建布尔掩码似乎是最快和最干净的方法。
@nb.njit
def delete_workaround(arr, num):
mask = np.zeros(arr.shape[0], dtype=np.int64) == 0
mask[np.where(arr == num)[0]] = False
return arr[mask]
a = np.array([[1,2,3,4],[5,6,7,8]])
delete_workaround(a, 3)
此解决方案还具有 保留数组维度的巨大优势,即使仅返回一行或一个空数组也是如此。这对 jitclasses很重要,因为 jitclasses 严重依赖固定维度。
既然您提出要求,我将向您展示一个将数组与列表相互转换的解决方案。由于 numba 中的所有 python 方法尚不支持反射列表,因此您必须对函数的某些部分使用包装器:
@nb.njit
def delete_lrow(arr_list, num):
idx_list = []
for i in range(len(arr_list)):
if (arr_list[i] != num).all():
idx_list.append(i)
res_list = [arr_list[i] for i in idx_list]
return res_list
def wrap_list_del(arr, num):
arr_list = list(arr)
return np.array(delete_lrow(arr_list, num))
arr = np.array([[1,2,3,4],[5,6,7,8],[10,11,5,13],[10,11,3,13],[10,11,99,13]])
arr2 = np.random.randint(0, 256, 100000*4).reshape(-1, 4)
%timeit delete_workaround(arr, 3)
# 1.36 µs ± 128 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%timeit wrap_list_del(arr, 3)
# 69.3 µs ± 4.97 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit delete_workaround(arr2, 3)
# 1.9 ms ± 68.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit wrap_list_del(arr2, 3)
# 1.05 s ± 103 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
因此,如果您已经拥有数组(即使您还没有数组,但您的数据是一致的类型),那么坚持使用数组对于小数组来说大约快 50 倍,大约 550大数组快 1 倍。
需要记住的是:Numpy 数组用于处理数值数据! Numpy 针对处理数字数据进行了大量优化!如果数据类型 (dtype
) 是常量并且没有超级特殊的东西需要它,那么将数值数据数组转换为另一个 "format" 绝对没有用(我几乎没有遇到过这种情况) .
对于 numba 优化代码尤其如此! Numba 严重依赖于 numpy 和常量 dtypes
/shapes 等。如果你想使用 jitclasses 则更是如此。
numba 现在支持 Numpy delete(但首先是参数是数组本身和包含应删除的索引的数组)
这是我第一次post来这里。我正在尝试删除 numba jitclass 中 numpy 数组中的一行。我写了下面的代码来删除任何包含 3 的行:
>>> a = np.array([[1,2,3,4],[5,6,7,8]])
>>> a
>>> array([[1, 2, 3, 4],
[5, 6, 7, 8]])
>>> i = np.where(a==3)
>>> i
>>> (array([0]), array([2]))
我无法使用 numpy.delete() 函数,因为它不受 numba 支持,并且无法将 None 类型的值分配给该行。我所能做的就是通过以下方式将 0 分配给该行:
>>> a[i[0]] = 0
>>> a
>>> array([[0, 0, 0, 0],
[5, 6, 7, 8]])
但我想完全删除该行。
任何帮助将不胜感激。
非常感谢。
我认为您需要再次将您的删除分配给变量 a,这对我有用。试试下面的代码:
import numpy as np
a = np.array([[1,2,3,4],[5,6,7,8]])
print(a)
i = np.where(a==3)
a=np.delete(a, i, 0) # assign it back to the variable
print(a)
欢迎来到 Stacoverflow。您可以简单地使用数组切片来 select 仅其中没有 3 的行。
下面的代码有点复杂,基本上涵盖了额外的细节,尽管你可以有一个更短的版本,删除不必要的行。键值分配是rows_final = [x for x in range(a.shape[0]) if x not in rows3]
代码:
import numpy as np
a = np.array([[1,2,3,4],[5,6,7,8],[10,11,3,13]])
ind = np.argwhere(a==3)
rows3 = ind[0]
cols3 = ind[1]
print ("Initial Array: \n", a)
print()
print("rows, cols of a==3 : ", rows3, cols3)
rows_final = [x for x in range(a.shape[0]) if x not in rows3]
a_final = a[rows_final,:]
print()
print ("Final Rows: \n", rows_final)
print ("Final Array: \n", a_final)
输出:
Initial Array:
[[ 1 2 3 4]
[ 5 6 7 8]
[10 11 3 13]]
rows, cols of a==3 : [0 2] [2 2]
Final Rows:
[1]
Final Array:
[[5 6 7 8]]
这实际上不是一件容易的事,因为 numba 有以下限制:
- 不支持
np.delete
- 不支持
np.all
和np.any
中的 - 不支持二维数组索引(至少不支持 bool 掩码)
- 没有或阻碍直接创建具有
np.zeros(shape, dtype=np.bool)
或类似功能的 bool 掩码
axis
关键字
但是您仍然可以采用多种方法来解决您的问题。我测试了一些,创建布尔掩码似乎是最快和最干净的方法。
@nb.njit
def delete_workaround(arr, num):
mask = np.zeros(arr.shape[0], dtype=np.int64) == 0
mask[np.where(arr == num)[0]] = False
return arr[mask]
a = np.array([[1,2,3,4],[5,6,7,8]])
delete_workaround(a, 3)
此解决方案还具有 保留数组维度的巨大优势,即使仅返回一行或一个空数组也是如此。这对 jitclasses很重要,因为 jitclasses 严重依赖固定维度。
既然您提出要求,我将向您展示一个将数组与列表相互转换的解决方案。由于 numba 中的所有 python 方法尚不支持反射列表,因此您必须对函数的某些部分使用包装器:
@nb.njit
def delete_lrow(arr_list, num):
idx_list = []
for i in range(len(arr_list)):
if (arr_list[i] != num).all():
idx_list.append(i)
res_list = [arr_list[i] for i in idx_list]
return res_list
def wrap_list_del(arr, num):
arr_list = list(arr)
return np.array(delete_lrow(arr_list, num))
arr = np.array([[1,2,3,4],[5,6,7,8],[10,11,5,13],[10,11,3,13],[10,11,99,13]])
arr2 = np.random.randint(0, 256, 100000*4).reshape(-1, 4)
%timeit delete_workaround(arr, 3)
# 1.36 µs ± 128 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%timeit wrap_list_del(arr, 3)
# 69.3 µs ± 4.97 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit delete_workaround(arr2, 3)
# 1.9 ms ± 68.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit wrap_list_del(arr2, 3)
# 1.05 s ± 103 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
因此,如果您已经拥有数组(即使您还没有数组,但您的数据是一致的类型),那么坚持使用数组对于小数组来说大约快 50 倍,大约 550大数组快 1 倍。
需要记住的是:Numpy 数组用于处理数值数据! Numpy 针对处理数字数据进行了大量优化!如果数据类型 (dtype
) 是常量并且没有超级特殊的东西需要它,那么将数值数据数组转换为另一个 "format" 绝对没有用(我几乎没有遇到过这种情况) .
对于 numba 优化代码尤其如此! Numba 严重依赖于 numpy 和常量 dtypes
/shapes 等。如果你想使用 jitclasses 则更是如此。
numba 现在支持 Numpy delete(但首先是参数是数组本身和包含应删除的索引的数组)