Numba try: if array.shape[1] - error: tuple index out of range. works without numba, doesn't work with @njit(fastmath=True, nogil=True, cache=True)
Numba try: if array.shape[1] - error: tuple index out of range. works without numba, doesn't work with @njit(fastmath=True, nogil=True, cache=True)
Numba 0.53.1,Python3.7.9,Windows10 64 位
这个 doctest 工作正常:
import numpy as np
def example_numba_tri(yp):
"""
>>> example_numba_tri(np.array([0.1, 0.5, 0.2, 0.3, 0.1, 0.7, 0.6, 0.4, 0.1]))
array([[0.1, 0.3, 0.6],
[0.5, 0.1, 0.4],
[0.2, 0.7, 0.1]])
"""
try:
if yp.shape[1] == 3:
pass
except:
yp = yp.reshape(int(len(yp) / 3), -1, order='F')
return yp
只需添加@njit(fastmath=True, nogil=True, cache=True)
:
from numba import njit
import numpy as np
@njit(fastmath=True, nogil=True, cache=True)
def example_numba_tri(yp):
"""
>>> example_numba_tri(np.array([0.1, 0.5, 0.2, 0.3, 0.1, 0.7, 0.6, 0.4, 0.1]))
array([[0.1, 0.3, 0.6],
[0.5, 0.1, 0.4],
[0.2, 0.7, 0.1]])
"""
try:
if yp.shape[1] == 3:
pass
except:
yp = yp.reshape(int(len(yp) / 3), -1, order='F')
return yp
并得到一个错误:
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Internal error at <numba.core.typeinfer.StaticGetItemConstraint object at 0x0000020CB55A7608>.
tuple index out of range
During: typing of static-get-item at C:/U1/main.py (1675)
Enable logging at debug level for details.
File "main.py", line 1675:
def example_numba_tri(yp):
<source elided>
try:
if yp.shape[1] == 3:
^
如何解决它以及为什么会发生这种情况?或者这是一个错误?
我读了 https://numba.pydata.org/numba-doc/dev/reference/pysupported.html#pysupported-exception-handling ,但似乎我做了那里写的所有事情。
更新:
- https://github.com/numba/numba/issues/6872
- 这对我也很有用https://numba.pydata.org/numba-doc/latest/user/troubleshoot.html#my-code-doesn-t-compile
- 分析和重构帮助我仅将大多数 cpu 密集部分转换为 numba
- numba.pydata.org/numba-doc/dev/reference/numpysupported.html 好像
不支持
order
(order='F'
)
请考虑以另一种方式重写您的代码,因为看起来您的 Numba 代码检查输入数据的类型,因此块 if yp.shape[1] == 3:
在编译阶段被检查,这就是为什么 [=12 不处理它的原因=]
请尝试下面的代码,它与您的代码相同,但是没有 order='F'
不想以任何方式使用 Numba。
from numba import njit
import numpy as np
@njit(fastmath=True, nogil=True, cache=True)
def example_numba_tri(yp):
return yp.reshape(int(len(yp) / 3), -1)
def wrapper_example_numba_tri(yp):
if len(yp.shape) > 1:
if yp.shape[1] == 3:
return yp
return example_numba_tri(yp)
if name == 'main':
x = np.array([0.1, 0.5, 0.2, 0.3, 0.1, 0.7, 0.6, 0.4, 0.1])
wrapper_example_numba_tri(x)
Numba 0.53.1,Python3.7.9,Windows10 64 位
这个 doctest 工作正常:
import numpy as np
def example_numba_tri(yp):
"""
>>> example_numba_tri(np.array([0.1, 0.5, 0.2, 0.3, 0.1, 0.7, 0.6, 0.4, 0.1]))
array([[0.1, 0.3, 0.6],
[0.5, 0.1, 0.4],
[0.2, 0.7, 0.1]])
"""
try:
if yp.shape[1] == 3:
pass
except:
yp = yp.reshape(int(len(yp) / 3), -1, order='F')
return yp
只需添加@njit(fastmath=True, nogil=True, cache=True)
:
from numba import njit
import numpy as np
@njit(fastmath=True, nogil=True, cache=True)
def example_numba_tri(yp):
"""
>>> example_numba_tri(np.array([0.1, 0.5, 0.2, 0.3, 0.1, 0.7, 0.6, 0.4, 0.1]))
array([[0.1, 0.3, 0.6],
[0.5, 0.1, 0.4],
[0.2, 0.7, 0.1]])
"""
try:
if yp.shape[1] == 3:
pass
except:
yp = yp.reshape(int(len(yp) / 3), -1, order='F')
return yp
并得到一个错误:
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Internal error at <numba.core.typeinfer.StaticGetItemConstraint object at 0x0000020CB55A7608>.
tuple index out of range
During: typing of static-get-item at C:/U1/main.py (1675)
Enable logging at debug level for details.
File "main.py", line 1675:
def example_numba_tri(yp):
<source elided>
try:
if yp.shape[1] == 3:
^
如何解决它以及为什么会发生这种情况?或者这是一个错误? 我读了 https://numba.pydata.org/numba-doc/dev/reference/pysupported.html#pysupported-exception-handling ,但似乎我做了那里写的所有事情。
更新:
- https://github.com/numba/numba/issues/6872
- 这对我也很有用https://numba.pydata.org/numba-doc/latest/user/troubleshoot.html#my-code-doesn-t-compile
- 分析和重构帮助我仅将大多数 cpu 密集部分转换为 numba
- numba.pydata.org/numba-doc/dev/reference/numpysupported.html 好像
不支持
order
(order='F'
)
请考虑以另一种方式重写您的代码,因为看起来您的 Numba 代码检查输入数据的类型,因此块 if yp.shape[1] == 3:
在编译阶段被检查,这就是为什么 [=12 不处理它的原因=]
请尝试下面的代码,它与您的代码相同,但是没有 order='F'
不想以任何方式使用 Numba。
from numba import njit
import numpy as np
@njit(fastmath=True, nogil=True, cache=True)
def example_numba_tri(yp):
return yp.reshape(int(len(yp) / 3), -1)
def wrapper_example_numba_tri(yp):
if len(yp.shape) > 1:
if yp.shape[1] == 3:
return yp
return example_numba_tri(yp)
if name == 'main':
x = np.array([0.1, 0.5, 0.2, 0.3, 0.1, 0.7, 0.6, 0.4, 0.1])
wrapper_example_numba_tri(x)