多维数组的 Numba 旋转矩阵

Numba Rotation Matrix with Multidimensional Arrays

我正在尝试使用 Numba 来加速某些功能,特别是执行给定三个角度的 3D 旋转的功能,如下所示:

import numpy as np
from numba import jit

@jit(nopython=True)
def rotation_matrix(theta_x, theta_y, theta_z):
    # Convert to radians. To ensure counter-clockwise (ccw) rotations, take
    # negative of angles.
    theta_x_rad = -np.radians(theta_x)
    theta_y_rad = -np.radians(theta_y)
    theta_z_rad = -np.radians(theta_z)
    # Define rotation matrices (yaw, pitch, roll)
    Rx = np.array([[1, 0,0],
                    [0, np.cos(theta_x_rad),-np.sin(theta_x_rad)],
                    [0, np.sin(theta_x_rad),np.cos(theta_x_rad) ]
                    ])


    Ry = np.array([[ np.cos(theta_y_rad), 0,np.sin(theta_y_rad)],
                    [ 0,1,0],
                    [-np.sin(theta_y_rad), 0,np.cos(theta_y_rad)]
                    ])


    Rz = np.array([[np.cos(theta_z_rad),-np.sin(theta_z_rad),0],
                    [np.sin(theta_z_rad),np.cos(theta_z_rad),0],
                    [0,0,1]
                    ])


    # Compute total rotation matrix
    R  = np.dot(Rz, np.dot( Ry, Rx ))
    #
    return R

函数比较简单,但是当Numba调用它时,当我尝试定义Rx时它抛出错误。看来Numba对多维数组有问题(?)。我不确定如何修改它以便 Numba 可以利用它。任何帮助将不胜感激。

问题来自整数和浮点类型值的混合。 Numba 尝试定义一种数组类型,发现 [1, 0, 0] 是一个整数列表,但整个数组是用整数列表和浮点数列表初始化的。由于整体类型不明确,类型推断很混乱并引发错误。您可以写 1.00.0 而不是 10 来解决这个问题。更一般地说,指定数组的 dtype 通常是一个好习惯,尤其是在 Numba 中,由于 类型推断 .

如果你想在第一次调用函数时避免运行时的编译错误,那么你可以精确参数类型。请注意,您可以使用 njit 而不是 nopython=True(更短)。生成的装饰器应该是 @njit('(float64, float64, float64)').