Numpy aggregate function shims, typing, and np.sort() in Numba
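I'm using Numba (0.44) with NumPy in nopython mode. Currently, Numba doesn't support NumPy aggregate functions along an arbitrary axis; it only supports computing these aggregates over the entire array. Given that, I decided to try my hand at writing some shims.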
In code:

```python
np.min(array)            # This works with Numba 0.44
np.min(array, axis = 0)  # This does not work with Numba 0.44 (no axis argument allowed)
```
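For reference, this is what plain NumPy (outside Numba) returns for each call form, so a shim can be checked against it. The array `a` is just an illustrative input.

```python
import numpy as np

# A small 2D array used as a reference input.
a = np.array([[3, 1, 2],
              [0, 5, 4]])

print(np.min(a))          # whole array -> 0
print(np.min(a, axis=0))  # one value per column -> [0 1 2]
print(np.min(a, axis=1))  # one value per row -> [1 0]
```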
Here is an example shim intended to reproduce np.min(array), including the axis argument:
```python
import numpy as np
import numba

@numba.jit(nopython = True)
def npmin (X, axis = -1):
    """
    Shim for broadcastable np.min().
    Allows np.min(array), np.min(array, axis = 0), and np.min(array, axis = 1)
    Note that the argument axis = -1 computes on the entire array.
    """
    if axis == 0:
        _min = np.sort(X.transpose())[:,0]
    elif axis == 1:
        _min = np.sort(X)[:,0]
    else:
        _min = np.sort(np.sort(X)[:,0])[0]
    return _min
```
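Without Numba, the shim works as expected and generalizes the behavior of np.min() to 2D arrays. Note that I'm using axis = -1 as a sentinel for computing the minimum over the entire array, mirroring np.min(array) called without an axis argument. (In NumPy itself, axis=-1 means the last axis, so this sentinel is specific to the shim.)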
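The sort-based trick can be checked against NumPy's built-in reductions in plain NumPy. This sketch assumes a small random integer array as input. Keep in mind that sorting each row costs O(n log n), whereas a true reduction is a single O(n) pass.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.integers(0, 100, size=(4, 6))

# np.sort sorts along the last axis by default, so column 0 of the
# sorted array holds each row's minimum.
row_min = np.sort(X)[:, 0]
# Transposing first turns the columns of X into the rows being sorted.
col_min = np.sort(X.transpose())[:, 0]
# The smallest of the row minima is the minimum of the whole array.
all_min = np.sort(row_min)[0]

assert np.array_equal(row_min, X.min(axis=1))
assert np.array_equal(col_min, X.min(axis=0))
assert all_min == X.min()
```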
Unfortunately, as soon as I add Numba to the mix, I get an error. Here is the traceback:
```
Traceback (most recent call last):
  File "shims.py", line 81, in <module>
    _min = npmin(a)
  File "/usr/local/lib/python3.7/site-packages/numba/dispatcher.py", line 348, in _compile_for_args
    error_rewrite(e, 'typing')
  File "/usr/local/lib/python3.7/site-packages/numba/dispatcher.py", line 315, in error_rewrite
    reraise(type(e), e, None)
  File "/usr/local/lib/python3.7/site-packages/numba/six.py", line 658, in reraise
    raise value.with_traceback(tb)
numba.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<function sort at 0x10abd5ea0>) with argument(s) of type(s): (array(int64, 2d, F))
 * parameterized
In definition 0:
    All templates rejected
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: resolving callee type: Function(<function sort at 0x10abd5ea0>)
[2] During: typing of call at shims.py (27)

File "shims.py", line 27:
def npmin (X, axis = -1):
    <source elided>
    if axis == 0:
        _min = np.sort(X.transpose())[:,0]
        ^

This is not usually a problem with Numba itself but instead often caused by
the use of unsupported features or an issue in resolving types.

To see Python/NumPy features supported by the latest release of Numba visit:
http://numba.pydata.org/numba-doc/dev/reference/pysupported.html
and
http://numba.pydata.org/numba-doc/dev/reference/numpysupported.html

For more information about typing errors and how to debug them visit:
http://numba.pydata.org/numba-doc/latest/user/troubleshoot.html#my-code-doesn-t-compile

If you think your code should work with Numba, please report the error message
and traceback, along with a minimal reproducer at:
https://github.com/numba/numba/issues/new
```
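I've verified that all the functions I'm using, with their respective arguments, are supported in Numba 0.44. Granted, the stack trace says the problem is my call to np.sort(array), but I suspect this may be a typing issue, because the function can return either a scalar (without an axis argument) or a 1D array (with an axis argument).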
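That said, I have a few questions:
- Is something wrong with my implementation? Can anyone pinpoint the unsupported feature I'm using, as the stack trace suggests?
- Or does this look like a bug in Numba?
- More generally, can shims like these work with Numba (0.44) at present?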
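On the typing suspicion: the traceback itself shows np.sort rejecting an argument of type array(int64, 2d, F), but even once that is resolved, Numba needs a single return type per compiled function, so a scalar in one branch and an array in another will not unify. A minimal sketch of one way around both issues, assuming a non-empty 2D input: every branch returns a 1D array of X's dtype (axis=-1 gives a length-1 array), and explicit loops replace np.sort. The name npmin_1d is hypothetical, and the try/except fallback only exists so the sketch also runs without Numba installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # plain-Python fallback so the sketch runs without Numba
    def njit(f):
        return f


@njit
def npmin_1d(X, axis=-1):
    """Hypothetical shim: every branch returns a 1D array with X's dtype,
    so Numba infers one return type. axis == -1 returns a length-1 array
    holding the global minimum. Assumes X is a non-empty 2D array."""
    n_rows, n_cols = X.shape
    if axis == 0:
        # Start from row 0 and keep the smallest value seen in each column.
        out = X[0, :].copy()
        for i in range(1, n_rows):
            for j in range(n_cols):
                if X[i, j] < out[j]:
                    out[j] = X[i, j]
    elif axis == 1:
        # Start from column 0 and keep the smallest value seen in each row.
        out = X[:, 0].copy()
        for i in range(n_rows):
            for j in range(1, n_cols):
                if X[i, j] < out[i]:
                    out[i] = X[i, j]
    else:
        # Whole-array minimum, wrapped in a 1-element array.
        out = np.empty(1, dtype=X.dtype)
        out[0] = np.min(X)
    return out
```

Callers that want the scalar for the whole-array case can take npmin_1d(X, -1)[0]. The cost of this design is a slightly awkward return value for axis=-1 in exchange for a function that compiles to one consistent signature.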
Here is an alternative shim for 2D arrays:
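```python
import numpy as np
import numba

@numba.jit(nopython=True)
def npmin2(X, axis=0):
    if axis == 0:
        # One minimum per column, keeping the input dtype
        _min = np.empty(X.shape[1], dtype=X.dtype)
        for i in range(X.shape[1]):
            _min[i] = np.min(X[:, i])
    elif axis == 1:
        # One minimum per row, keeping the input dtype
        _min = np.empty(X.shape[0], dtype=X.dtype)
        for i in range(X.shape[0]):
            _min[i] = np.min(X[i, :])
    else:
        # Otherwise _min would be undefined on this path
        raise ValueError("axis must be 0 or 1")
    return _min
```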
You'd have to find a workaround for the axis=-1 case, though: it would return a scalar while the other axis values return arrays, and Numba can't "unify" those into a single consistent return type.
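One possible workaround, sketched here as an assumption rather than a tested recipe: keep the dispatch on axis in regular Python and give each jitted kernel a single return type. The wrapper then may return a scalar or an array freely, because only the kernels are compiled. The names npmin_dispatch, _min_all, _min_axis0 and _min_axis1 are hypothetical, and the wrapper uses None for "whole array", as NumPy does, rather than the -1 sentinel.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # plain-Python fallback so the sketch runs anywhere
    def njit(f):
        return f


@njit
def _min_all(X):
    # Whole-array reduction: np.min without an axis works in nopython mode.
    return np.min(X)


@njit
def _min_axis0(X):
    # One minimum per column (same loop as npmin2 with axis=0).
    out = np.empty(X.shape[1], dtype=X.dtype)
    for j in range(X.shape[1]):
        out[j] = np.min(X[:, j])
    return out


@njit
def _min_axis1(X):
    # One minimum per row (same loop as npmin2 with axis=1).
    out = np.empty(X.shape[0], dtype=X.dtype)
    for i in range(X.shape[0]):
        out[i] = np.min(X[i, :])
    return out


def npmin_dispatch(X, axis=None):
    """Hypothetical wrapper: runs as ordinary Python, so its return type may
    vary; each jitted kernel above has exactly one return type."""
    if axis is None:
        return _min_all(X)
    if axis == 0:
        return _min_axis0(X)
    if axis == 1:
        return _min_axis1(X)
    raise ValueError("axis must be None, 0 or 1")
```

The trade-off is that npmin_dispatch itself cannot be called from inside other nopython-mode functions; jitted callers would call the individual kernels directly.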
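Performance, at least on my machine, seems roughly on par with calling the equivalent np.min: sometimes np.min is faster and sometimes npmin2 wins, depending on the input array size and the axis.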
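A minimal timing harness for reproducing this kind of comparison yourself might look like the following. It assumes float64 input and times only the axis=0 case; npmin2_axis0 and compare are hypothetical names. Timings depend heavily on the machine, so none are quoted here.

```python
import timeit

import numpy as np

try:
    from numba import njit
except ImportError:  # plain-Python fallback so the harness still runs
    def njit(f):
        return f


@njit
def npmin2_axis0(X):
    # Same column-by-column loop as npmin2 with axis=0.
    out = np.empty(X.shape[1], dtype=X.dtype)
    for i in range(X.shape[1]):
        out[i] = np.min(X[:, i])
    return out


def compare(shape, repeat=3, number=10):
    """Return the best-of-`repeat` times for np.min vs. the jitted loop."""
    X = np.random.default_rng(0).random(shape)
    npmin2_axis0(X)  # warm-up call so JIT compilation is not timed
    t_np = min(timeit.repeat(lambda: np.min(X, axis=0), repeat=repeat, number=number))
    t_nb = min(timeit.repeat(lambda: npmin2_axis0(X), repeat=repeat, number=number))
    return t_np, t_nb


for shape in [(100, 100), (1000, 100), (100, 1000)]:
    t_np, t_nb = compare(shape)
    print(shape, f"np.min: {t_np:.4f}s", f"npmin2_axis0: {t_nb:.4f}s")
```

Varying the shape and swapping in the axis=1 loop shows whether one approach consistently wins for a given layout.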