多维数组的 Numba 旋转矩阵
Numba Rotation Matrix with Multidimensional Arrays
我正在尝试使用 Numba 来加速某些功能,特别是执行给定三个角度的 3D 旋转的功能,如下所示:
import numpy as np
from numba import jit
@jit(nopython=True)
def rotation_matrix(theta_x, theta_y, theta_z):
# Convert to radians. To ensure counter-clockwise (ccw) rotations, take
# negative of angles.
theta_x_rad = -np.radians(theta_x)
theta_y_rad = -np.radians(theta_y)
theta_z_rad = -np.radians(theta_z)
# Define rotation matrices (yaw, pitch, roll)
Rx = np.array([[1, 0,0],
[0, np.cos(theta_x_rad),-np.sin(theta_x_rad)],
[0, np.sin(theta_x_rad),np.cos(theta_x_rad) ]
])
Ry = np.array([[ np.cos(theta_y_rad), 0,np.sin(theta_y_rad)],
[ 0,1,0],
[-np.sin(theta_y_rad), 0,np.cos(theta_y_rad)]
])
Rz = np.array([[np.cos(theta_z_rad),-np.sin(theta_z_rad),0],
[np.sin(theta_z_rad),np.cos(theta_z_rad),0],
[0,0,1]
])
# Compute total rotation matrix
R = np.dot(Rz, np.dot( Ry, Rx ))
#
return R
函数比较简单,但是当Numba调用它时,当我尝试定义Rx
时它抛出错误。看来Numba对多维数组有问题(?)。我不确定如何修改它以便 Numba 可以利用它。任何帮助将不胜感激。
问题来自整数和浮点类型值的混合。 Numba 尝试定义一种数组类型,发现 [1, 0, 0]
是一个整数列表,但整个数组是用整数列表和浮点数列表初始化的。由于整体类型不明确,类型推断很混乱并引发错误。您可以写 1.0
和 0.0
而不是 1
和 0
来解决这个问题。更一般地说,指定数组的 dtype
通常是一个好习惯,尤其是在 Numba 中,由于 类型推断 .
如果你想在第一次调用函数时避免运行时的编译错误,那么你可以精确参数类型。请注意,您可以使用 njit
而不是 nopython=True
(更短)。生成的装饰器应该是 @njit('(float64, float64, float64)')
.
我正在尝试使用 Numba 来加速某些功能,特别是执行给定三个角度的 3D 旋转的功能,如下所示:
import numpy as np
from numba import jit
@jit(nopython=True)
def rotation_matrix(theta_x, theta_y, theta_z):
# Convert to radians. To ensure counter-clockwise (ccw) rotations, take
# negative of angles.
theta_x_rad = -np.radians(theta_x)
theta_y_rad = -np.radians(theta_y)
theta_z_rad = -np.radians(theta_z)
# Define rotation matrices (yaw, pitch, roll)
Rx = np.array([[1, 0,0],
[0, np.cos(theta_x_rad),-np.sin(theta_x_rad)],
[0, np.sin(theta_x_rad),np.cos(theta_x_rad) ]
])
Ry = np.array([[ np.cos(theta_y_rad), 0,np.sin(theta_y_rad)],
[ 0,1,0],
[-np.sin(theta_y_rad), 0,np.cos(theta_y_rad)]
])
Rz = np.array([[np.cos(theta_z_rad),-np.sin(theta_z_rad),0],
[np.sin(theta_z_rad),np.cos(theta_z_rad),0],
[0,0,1]
])
# Compute total rotation matrix
R = np.dot(Rz, np.dot( Ry, Rx ))
#
return R
函数比较简单,但是当Numba调用它时,当我尝试定义Rx
时它抛出错误。看来Numba对多维数组有问题(?)。我不确定如何修改它以便 Numba 可以利用它。任何帮助将不胜感激。
问题来自整数和浮点类型值的混合。 Numba 尝试定义一种数组类型,发现 [1, 0, 0]
是一个整数列表,但整个数组是用整数列表和浮点数列表初始化的。由于整体类型不明确,类型推断很混乱并引发错误。您可以写 1.0
和 0.0
而不是 1
和 0
来解决这个问题。更一般地说,指定数组的 dtype
通常是一个好习惯,尤其是在 Numba 中,由于 类型推断 .
如果你想在第一次调用函数时避免运行时的编译错误,那么你可以精确参数类型。请注意,您可以使用 njit
而不是 nopython=True
(更短)。生成的装饰器应该是 @njit('(float64, float64, float64)')
.