包含 numpy 数组的元组排序出现 numba 错误

numba error with tuple sorting containing numpy arrays

我有一个(有效的)函数,它使用 heapq 模块构建元组的优先级队列,我想用 numba 编译它,但是我得到一个很长且不清楚的错误。它似乎归结为队列所需的元组顺序比较问题。元组具有固定格式,其中第一项是一个浮点数(我关心其顺序),然​​后是一个 numpy 数组,我需要计算但通常在 运行 时不会进行比较。这是故意的,因为对 numpy 数组的比较会产生一个不能在条件中使用的数组并引发异常。但是,我想 numba 需要至少为元组中的所有项目定义一个标量产生比较,因此 numba 错误。

我有一个非常简单的例子:

@numba.njit
def f():
    return 1 if (1, numpy.arange(3)) < (2, numpy.arange(3)) else 2
f()

numba 编译失败的地方(没有 numba 它可以工作,因为它永远不需要像原始代码那样实际比较数组)。

这是一个稍微简单但可能更清晰的示例,它显示了我实际在做什么:

from heapq import heappush
import numpy
import numba
@numba.njit
def f(n):
  heap = [(1, 0, numpy.random.rand(2, 3))]
  for unique_id in range(n):
    order = numpy.random.rand()
    data = numpy.random.rand(2, 3)
    heappush(heap, (order, unique_id, data))
  return heap[0]
f(100)

这里的order是我关心的变量在队列中的顺序,unique_id是避免这种情况的技巧当 order 相同时,比较继续 data 并抛出异常。

我试图绕过在元组中将 numpy 数组转换为列表并返回数组进行计算的问题,但是在编译时,numba 版本比解释版本慢,即使数组相当小(通常为 2x3)。如果不进行转换,我需要将代码重写为循环,我希望避免这种循环(但这是可行的)。

是否有更好的替代方法来使用 numba,希望 运行 比 python 解释器更快?

我会尝试根据您提供的最小示例进行回复。

我认为这里的问题不在于numba是否能够对元组的所有元素进行比较,而在于将这种比较的结果存储在哪里。这在尝试执行您的示例时返回的错误日志中有说明:

cannot store {i8*, i8*, i64, i64, i8*, [1 x i64], [1 x i64]} to i1*: mismatching types

基本上,您试图将一对浮点数和一对数组之间的比较结果存储到一个布尔值中,而 numba 不知道该怎么做。

如果您只对比较元组的第一个元素感兴趣,我能想到的最快的解决方法是强制只对第一个元素进行比较,例如

@numba.njit
def f():
    return 1 if (1, numpy.arange(3))[0] < (2, numpy.arange(3))[0] else 2
f()

如果这不适用于您的用例,请提供更多详细信息。

编辑

根据您提供的进一步信息,我认为解决此问题的最佳方法是避免将 numpy 数组推入堆。由于您只对堆的排序属性感兴趣,因此您可以将键推送到堆并将相应的 numpy 数组存储在单独的字典中,使用与推送到堆中的相同值作为键。

作为旁注,当您在 nopython-jitted 函数中使用标准库函数时,您正在求助于这些函数的特定 numba re-implementation 而不是“原始” python那些。可在 here.

中找到 numba 中可用 python 功能的完整列表

好的,我找到了解决问题的方法:由于将数组存储在堆元组中是导致numba错误的原因,因此将其存储在具有唯一键的单独字典中并仅存储键就足够了在堆元组中。例如,使用整数作为键:

from heapq import heappush
import numpy
import numba
@numba.njit
def f(n):
  key = 0
  array_storage = {key: numpy.random.rand(2, 3)}
  heap = [(1.0, key)]
  for _ in range(n):
    order = numpy.random.rand()
    data = numpy.random.rand(2, 3)
    key += 1
    heappush(heap, (order, key))
    array_storage[key] = data
  return heap[0]
f(100)

现在可以比较堆中的元组以产生一个布尔值,我仍然可以将数据与其元组相关联。我并不完全满意,因为它似乎是一种解决方法,但它工作得很好并且并不过分复杂。如果谁有更好的请告诉我!