numba 和 numpy.expand_dims
numba and numpy.expand_dims
我正在重写我的一些函数以适合 Numba。现在我有一个函数,我在我的脚本中多次调用不同维度的输入数组。
def FormHistMatrix2(x,Whc,Lm):
if x.ndim == 1:
x = np.expand_dims(x,axis=1)
[N,Ncells] = x.shape
这是我函数的开头,Numba 抛出以下错误:
TypingError: Cannot unify array(float64, 2d, A) and array(float64, 3d, A) for 'x', defined at C:/Users/DNP_Student_3/Documents/Python Scripts/GCFuncsTests.py (332)
在这种情况下,'x' 是一个二维数组,但在其他情况下,它可以是一个一维数组。
那么 Numba 不喜欢 if 循环吗?或者这里发生了什么?
在 Numba 中,与标准 python 不同,变量在函数执行期间不能更改其类型。您应该能够将 np.expand_dims
的调用结果分配给另一个变量,它会起作用。如果有时 x
是 1d,有时它是 2d,只要在函数执行过程中所有变量的类型保持一致就可以了。
JoshAdel 所说的大体上是正确的,但在这种情况下的问题是,根据输入类型,您需要一个不同的 implementation/specialization 函数。
Numba 在这种情况下具有 @generated_jit
-decorator。
在您的情况下,您需要编写一个专门的 expand-dims 函数,该函数取决于输入数组的维度:
import numba as nb
@nb.generated_jit(nopython=True)
def nb_expander(x):
if x.ndim == 1:
return lambda x: np.expand_dims(x, axis=1)
else:
return lambda x: x
需要从您的其他函数中调用此函数:
@nb.njit
def FormHistMatrix2(x, Whc, Lm):
x = nb_expander(x)
[N, Ncells] = x.shape
这现在适用于维度 1 和维度 2 的 x
。对于 x.ndim==3
,您还需要对形状实施类似的方法。
我正在重写我的一些函数以适合 Numba。现在我有一个函数,我在我的脚本中多次调用不同维度的输入数组。
def FormHistMatrix2(x,Whc,Lm):
if x.ndim == 1:
x = np.expand_dims(x,axis=1)
[N,Ncells] = x.shape
这是我函数的开头,Numba 抛出以下错误:
TypingError: Cannot unify array(float64, 2d, A) and array(float64, 3d, A) for 'x', defined at C:/Users/DNP_Student_3/Documents/Python Scripts/GCFuncsTests.py (332)
在这种情况下,'x' 是一个二维数组,但在其他情况下,它可以是一个一维数组。 那么 Numba 不喜欢 if 循环吗?或者这里发生了什么?
在 Numba 中,与标准 python 不同,变量在函数执行期间不能更改其类型。您应该能够将 np.expand_dims
的调用结果分配给另一个变量,它会起作用。如果有时 x
是 1d,有时它是 2d,只要在函数执行过程中所有变量的类型保持一致就可以了。
JoshAdel 所说的大体上是正确的,但在这种情况下的问题是,根据输入类型,您需要一个不同的 implementation/specialization 函数。
Numba 在这种情况下具有 @generated_jit
-decorator。
在您的情况下,您需要编写一个专门的 expand-dims 函数,该函数取决于输入数组的维度:
import numba as nb
@nb.generated_jit(nopython=True)
def nb_expander(x):
if x.ndim == 1:
return lambda x: np.expand_dims(x, axis=1)
else:
return lambda x: x
需要从您的其他函数中调用此函数:
@nb.njit
def FormHistMatrix2(x, Whc, Lm):
x = nb_expander(x)
[N, Ncells] = x.shape
这现在适用于维度 1 和维度 2 的 x
。对于 x.ndim==3
,您还需要对形状实施类似的方法。