numba:按行乘法数组

numba: multiply arrays rowwise

我有 numpy 数组形状 (2,5) 和 (2,),我想将它们相乘 rowvise

a = np.array([[3,5,6,9,10],[4,7,8,11,12]])
b = np.array([-1,2])

来自 numpy: multiply arrays rowwise 我知道这适用于 numpy:

a * b[:,None] 

给出正确的输出

array([[ -3,  -5,  -6,  -9, -10],
       [  8,  14,  16,  22,  24]])

但是对于 numba 它不再起作用了,我收到了一堆错误消息。

代码:

import numpy as np
from numba import njit

@njit()
def fct(a,b):
    c = a * b[:,None]
    return c

a = np.array([[3,5,6,9,10],[4,7,8,11,12]])
b = np.array([-1,2])
A = fct(a, b)
print(A)

我把这段代码放在一个名为 numba_questionA.py 的文件中。 运行 它给出了错误信息:

numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function getitem>) found for signature:
 
 >>> getitem(array(int32, 1d, C), Tuple(slice<a:b>, none))
 
There are 22 candidate implementations:
  - Of which 20 did not match due to:
  Overload of function 'getitem': File: <numerous>: Line N/A.
    With argument(s): '(array(int32, 1d, C), Tuple(slice<a:b>, none))':
   No match.
  - Of which 2 did not match due to:
  Overload in function 'GetItemBuffer.generic': File: numba\core\typing\arraydecl.py: Line 162.
    With argument(s): '(array(int32, 1d, C), Tuple(slice<a:b>, none))':
   Rejected as the implementation raised a specific error:
     TypeError: unsupported array index type none in Tuple(slice<a:b>, none)
  raised from numba_questionA.py

During: typing of intrinsic-call at numba_questionA.py
During: typing of static-get-item at numba_questionA.py

File "numba_questionA.py", line 6:
def fct(a,b):
    c = a * b[:,None]
    ^

Numba 说它不能使用 None 作为数组索引,所以你可以替换

b[:, None]

来自

b.reshape(-1, 1)

但是,对于 a * b[:,None].

这样的表达式,Numba 可能不会比 Numpy 快

但是,如果您的数组确实很大,您可以利用 Numba 的并行化:

@nb.njit(parallel=True)
def fct(a, b):
    c = np.empty_like(a)
    for i in nb.prange(a.shape[1]):
        c[:, i] = a[:, i] * b
    return c

使用 guvectorize 也是一种半自动广播数据的好方法。一个好处是您可以针对不同的目标进行编译(cpuparallelcuda)。

对于像您的示例这样的小型数组,并行化可能只会引入开销。

@nb.guvectorize(["void(int32[:], int32[:], int32[:])"], 
             "(n), ()->(n)", nopython=True, target="cpu")
def fct(a, b, out):
    out[:] = a * b

A = fct(a, b)

最近的 Numba 版本也可以自动推断数据类型,如果您也愿意提供输出数组,并且只针对 cpu 目标进行编译,所以:

@nb.guvectorize("(n),()->(n)", nopython=True, target="cpu")
def fct(a, b, out):
    out[:] = a * b

A = np.empty_like(a)
fct(a, b, A)