How to pass proper datatype to class using Numba?

I am trying to use numba to make my code run faster. However, the code throws the following error:

This error may have been caused by the following argument(s):
- argument 0: Unsupported array dtype: object
- argument 1: Unsupported array dtype: object

I have a class written like this:

spec = [
    ('train_x', float64[:,:]),
    ('train_y', float64[:]),
    ('test_x', float64[:,:]),
    ('test_y', float64[:]),
]

@jitclass(spec)
class num_features:
    def __init__(self, train_x,  test_x, train_y, test_y):
        self.train_x, self.train_y = train_x, train_y
        self.test_x, self.test_y = test_x, test_y
        self.X_train, self.Y_train = [] , []
        self.X_test, self.Y_test = [] , []

    @property
    def extract_stats(self, matrix):
    ...

I instantiate the class like this:

obj = num_features(train_x.to_numpy(), test_x.to_numpy(), train_y, test_y)

train_x and test_x are pandas DataFrames.

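Several things are wrong in your code. First, you cannot use regular Python lists in a numba jitclass; every attribute has to be typed. You need to declare those attributes as ListType in the spec and give them the element type they will hold, e.g. float64.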

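Second, the actual error you are seeing occurs because train_x and test_x are being passed as numpy arrays that hold non-float64 data. That is what "Unsupported array dtype: object" is telling you: the arrays for argument 0 and argument 1 are object arrays, i.e. arrays of Python objects.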

Pass a dtype when you convert them to numpy arrays.
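To see the difference the dtype makes, here is a minimal sketch. The DataFrame `df` and its columns `a` and `b` are made up for illustration; your real frames may end up with object dtype for a different reason (for example a column of strings or mixed types), so treat this as a way to check, not as a diagnosis of your data.

```python
import numpy as np
import pandas as pd

# Hypothetical frame whose columns are stored as generic Python objects.
df = pd.DataFrame({"a": [1.0, 2.0], "b": [3, 4]}, dtype=object)

# Without a dtype, to_numpy() keeps the object dtype, which numba rejects.
raw = df.to_numpy()
print(raw.dtype)

# Passing a dtype asks pandas to convert the values to float64.
fixed = df.to_numpy(np.float64)
print(fixed.dtype)
```

Printing `train_x.dtypes` before converting is a quick way to find which columns are not numeric.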

Also, don't get fancy with tuple assignment. numba is picky enough already, so just do one assignment per line.

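import numpy as np
from numba import float64, typed, types
from numba.experimental import jitclass  # Numba >= 0.49; older releases used `from numba import jitclass`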

spec = [
    ('train_x', float64[:,:]),
    ('train_y', float64[:]),
    ('test_x', float64[:,:]),
    ('test_y', float64[:]),
    ('X_train', types.ListType(types.float64)),
    ('Y_train', types.ListType(types.float64)),
    ('X_test', types.ListType(types.float64)),
    ('Y_test', types.ListType(types.float64))
]

@jitclass(spec)
class num_features:
    def __init__(self, train_x,  test_x, train_y, test_y):
        self.train_x = train_x
        self.train_y = train_y
        self.test_x = test_x
        self.test_y = test_y
        self.X_train = typed.List.empty_list(types.float64)
        self.Y_train = typed.List.empty_list(types.float64)
        self.X_test = typed.List.empty_list(types.float64)
        self.Y_test = typed.List.empty_list(types.float64)

    @property
    def extract_stats(self, matrix):
    ...

Now, to actually instantiate the class, you need to pass in float64 arrays. You can use:

obj = num_features(train_x.to_numpy(np.float64),
                   test_x.to_numpy(np.float64),
                   train_y.astype(np.float64),
                   test_y.astype(np.float64))