使用 Numba 在每一行上应用多个函数

Applying multiple functions on each row using Numba

我有一个很大的二维 NumPy 数组,假设有 500 万行和 10 列。我想根据使用 Numba @jitclass 实现的一些有状态逻辑来构建更多列。假设要创建 50 个这样的新列。这个想法是在 Numba @jit 函数中遍历所有 10 列的行,并且对于每一行,应用我的 50 "filters" 中的每一行来生成一个新的单元格。所以:

 Source1..Source10    Derived1..Derived50
[array of 10 inputs] [array of 50 outputs]
     ... 5 million rows like this ...

问题是,我无法将 "filters" 的列表或元组传递给 @jit(nopython=True) 函数,因为它们不是同质的:

@numba.jit(nopython=True)
def calc_derived(source, derived, filters):
    for srcidx, src in enumerate(source):
        for filtidx, filt in enumerate(filters): # doesn't work
            derived[srcidx,filtidx] = filt.transform(src)

以上内容不起作用,因为 filters 是一堆不同的 class。据我所知,即使让它们派生自一个共同的基础 class 也不够好。

我有可能交换循环的顺序,并在 @jit 函数之外循环 50 个过滤器,但这意味着整个源数据集将被加载 50 次而不是一次,非常浪费。

您有解决 Numba "homogenous lists only" 要求的技术吗?

要获得同类列表,您可以构建所有过滤器的 transform 函数的列表。在这种情况下,所有列表元素的类型都是 method.

# filters = list of filters
transforms = [x.transform for x in filters]

然后将 transforms 传递给 calc_derived() 而不是 filters

编辑: 在我的系统上,看起来 numba 会接受这个,但前提是 nopython=False

您最初询问是否使用一个循环遍历行的函数来执行此操作,并将过滤器列表应用于每一行。这种方法的一个挑战是 numba 需要知道或能够推断出每个函数的 input/output 类型。在这种情况下,我不知道有什么方法可以满足 numba 的要求(这并不是说 none 存在)。如果有办法做到这一点,它可能是一个更好的解决方案(我想知道它是什么)。

另一种方法是将循环遍历行的代码移动到过滤器本身。因为过滤器是 numba 函数,所以这应该保持速度。应用过滤器的函数将不再使用 numba;它会简单地遍历过滤器列表。但是,由于过滤器的数量相对于数据矩阵的大小而言很小,希望这不会对速度产生太大的影响。由于此函数不再使用 numba,因此 'heterogeneous list' 问题将不再是问题。

这种方法在我测试时有效(nopython 模式很好)。在测试用例中,作为 numba 函数实现的过滤器比作为 class 方法实现的过滤器快 10-18 倍(即使 classes 被实现为 numba jitclasses;不确定发生了什么那里)。为了获得一点模块化,可以将过滤器构造为闭包,以便可以使用不同的参数定义类似的过滤器。

例如,这里有计算幂和的过滤器。给定一个矩阵 x,过滤器对 x 的列进行运算,为每一行提供一个输出。它 returns 一个向量 v,其中 v[i] = sum(x[i, :] ** power)

# filter constructor
def sumpow(power):

    @numba.jit(nopython=True)
    def run_filter(x):
        (nrows, ncols) = x.shape
        result = np.zeros(nrows)
        for i in range(nrows):
            for j in range(ncols):
                result[i] += x[i,j] ** power
        return result

    return run_filter

# define filters
sum1 = sumpow(1) # sum of elements
sum2 = sumpow(2) # sum of elements squared

# apply a single filter
v = sum2(x)

应用多个过滤器的函数如下所示。每个过滤器的输出堆叠成输出的一列。

def apply_filters(x, filters):

    result = np.empty((x.shape[0], len(filters)))

    for (i, f) in enumerate(filters):
        result[:, i] = f(x)

    return result


y = apply_filters(x, [sum1, sum2])

计时结果

  • 数据矩阵:从标准正态分布中抽取的随机条目,float64,500 万行 x 10 列。所有方法都使用相同的矩阵进行测试。
  • 过滤器:sum2 以上过滤器,在列表中重复 20 次:[sum2, sum2, ...]
  • 使用 IPython 的 %timeit 函数计时,最好的 3 次运行
  • 所有方法的数值输出一致
  • Numba function filters (as shown above): 2.25s
  • Numba jitclass filters: 28.3s
  • Pure NumPy (using vectorized ops, no loops): 8.64s

我想 Numba 相对于 NumPy 可能会获得更复杂的过滤器。