Python: Numpy / Numba 比较数组

Python: Numpy / Numba comparing arrays

我提供了一个没有 numba 装饰器的最小工作示例。我知道 numba 给我一个 a != b 的错误,其中 ab 是数组。知道如何让它与 numba 一起工作吗?

我还注意到,numba 可以处理展平数组,即 a.flatten() != b.flatten(). 不幸的是,我不想比较第 1 列的最后一个元素与第 2 列的第一个元素。我假设,有一种计算步幅和从平面数组中删除元素的方法,但我认为它既不快,也不可读,也不可维护。

array2d = np.array([[1, 0, 1],
                    [1, 1, 0],
                    [0, 0, 1],
                    [2, 3, 5]])

#@numba.jit(nopython=True)
def TOY_compute_changes(array2d):
    array2d = np.vstack([[False, False, False], array2d[:-1] != array2d[1:]])
    return array2d

TOY_compute_changes(array2d)
array([[False, False, False],
       [False,  True,  True],
       [ True,  True,  True],
       [ True,  True,  True]])

如果我没猜错,这应该可行:

a = np.array([
    [0, 1, 0],
    [0, 1, 0],
    [0, 1, 0],
])

b = np.array([
    [1, 0, 1],
    [0, 1, 0],
    [0, 1, 0],
])

@numba.jit(nopython=True)
def not_eq(a, b):
    return np.logical_not(a == b)

print(not_eq(a, b))

输出:

[[ True  True  True]
 [False False False]
 [False False False]]

字符串示例:

a = np.array([['a', 'b', 'c'], ['x', 'y', 'z']])
b = np.array([['x', 'y', 'z'], ['x', 'y', 'z']])
print(not_eq(a, b))

输出:

[[ True  True  True]
 [False False False]]