numpy.hstack 替代 numba.njit

numpy.hstack alternative for numba.njit

各位程序员大家好!

我想使用这段代码,但 np.hstack 函数似乎与 numba.njit 装饰器不兼容:

import numpy as np
import numba

@numba.njit
def main():
    J_1 = np.array([[-64., 25.6, 25.6, 12.8], [25.6, -25.6, 0., 0.], [25.6, 0., -25.6, 0.], [12.8, 0., 0., -652.8]])
    J_2 = np.array([[-85.33333333, 34.13333333, 34.13333333, 17.06666667], [34.13333333, -34.13333333, 0., 0.], [34.13333333, 0., -34.13333333, 0.], [17.06666667, 0., 0., -870.4]])
    J_3 = np.array([[85.33333333, -34.13333333, -34.13333333, -17.06666667], [-34.13333333, 34.13333333, -0., -0.], [-34.13333333, -0., 34.13333333, -0.], [-17.06666667, -0., -0., 870.4]])
    J_4 = np.array([[-64., 25.6, 25.6, 12.8], [25.6, -25.6, 0., 0.], [25.6, 0., -25.6, 0.], [12.8, 0., 0., -652.8]])
    J_old = [[J_1, J_2], [J_3, J_4]]
    J_stack = np.hstack(J_old[0])
    for row in J_old[1:]:
        col = np.hstack(row)
        J = np.vstack((J_stack, col))

    print(J)

if __name__ == '__main__':
    main()

输出:

C:\Users\Artur\Anaconda\python.exe C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py
Traceback (most recent call last):
  File "C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py", line 19, in <module>
    main()
  File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\dispatcher.py", line 401, in _compile_for_args
    error_rewrite(e, 'typing')
  File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\dispatcher.py", line 344, in error_rewrite
    reraise(type(e), e, None)
  File "C:\Users\Artur\Anaconda\lib\site-packages\numba\core\utils.py", line 80, in reraise
    raise value.with_traceback(tb)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<function hstack at 0x000001981A60B558>) with argument(s) of type(s): (list(array(float64, 2d, C)))
 * parameterized
In definition 0:
    TypeError: np.hstack(): expecting a non-empty tuple of arrays, got list(array(float64, 2d, C))
    raised from C:\Users\Artur\Anaconda\lib\site-packages\numba\core\typing\npydecl.py:779
In definition 1:
    TypeError: np.hstack(): expecting a non-empty tuple of arrays, got list(array(float64, 2d, C))
    raised from C:\Users\Artur\Anaconda\lib\site-packages\numba\core\typing\npydecl.py:779
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: resolving callee type: Function(<function hstack at 0x000001981A60B558>)
[2] During: typing of call at C:/Users/Artur/Desktop/RL_framework/help_functions/test2.py (11)


File "test2.py", line 11:
def main():
    <source elided>
    J_old = [[J_1, J_2], [J_3, J_4]]
    J_stack = np.hstack(J_old[0])
    ^


Process finished with exit code 1

最初是这个片段:

J_old = [[J_1, J_2], [J_3, J_4]]
J_stack = np.hstack(J_old[0])
for row in J_old[1:]:
    col = np.hstack(row)
    J = np.vstack((J_stack, col))

J = np.bmat([[J_1, J_2], [J_3, J_4]]) 的替代品,后者也不适用于 numba.njit 装饰器。

np.hstack 是 numba supported numpy features 之一,错误消息清楚地说明了其他内容。作为一个简单的解决方案,您可以在分配四个块后使用以下一个衬垫来构造 J(在 numba 0.48.0 上测试):

J = np.vstack((np.hstack((J_1, J_2)),np.hstack((J_3, J_4))))

这给出的结果等同于 np.bmat 的输出。

希望这对您有所帮助。

来自错误信息

TypeError: np.hstack(): expecting a non-empty tuple of arrays, got list(array(float64, 2d, C))

我们看到了问题:hstack 的 Numba 版本需要一个数组元组,而你给了它一个数组列表。 (hstack 的 NumPy 版本更宽容,可以让你使用列表。)

这可以通过在 J_old:

中简单地使用元组而不是列表来解决
J_old = [(J_1, J_2), (J_3, J_4)]

更普遍的是,在 Numba 中尽可能总是使用元组,因为列表支持存在但相当不完整,而且许多函数对列表不满意,即使在技术上可能(尽管不是惯用的)使用列表作为NumPy 的参数。

(当然,Yacola 的解决方案更像是一个改进——我只是想指出,这是 Numba 工作所需的最小变化。)