在 Numba 中关闭列表反射

Turn off list reflection in Numba

我正在尝试使用 Numba 加速我的代码。我传递给函数的参数之一是可变的列表列表。当我尝试更改其中一个子列表时,出现此错误:

Failed in nopython mode pipeline (step: nopython mode backend) cannot reflect element of reflected container: reflected list(reflected list(int64))

我实际上并不关心将我对本机列表所做的更改反映到原始 Python 列表中。我该如何告诉 Numba 不要反映这些变化?关于 Numba 中的列表反射,文档非常模糊。

谢谢,

直接引用自the docs

In nopython mode, Numba does not operate on Python objects. list are compiled into an internal representation. Any list arguments must be converted into this representation on the way in to nopython mode and their contained elements must be restored in the original Python objects via a process called reflection.

Reflection is required to maintain the same semantics as found in regular Python code. However, the reflection process can be expensive for large lists and it is not supported for lists that contain reflected data types. Users cannot use list-of-list as an argument because of this limitation.

您最好的选择是给出一个形状为 len(ll) x max(len(x) for x in ll) 的 2D numpy 数组,ll 是列表的列表。我自己使用类似这样的东西来实现这一点,然后将 arr, lengths 传递给 njit 编译函数:

def make_2D_array(lis):
    """Funciton to get 2D array from a list of lists
    """
    n = len(lis)
    lengths = np.array([len(x) for x in lis])
    max_len = np.max(lengths)
    arr = np.zeros((n, max_len))

    for i in range(n):
        arr[i, :lengths[i]] = lis[i]
    return arr, lengths

HTH.

如果将列表参数的列表传递给 numba,则应使用 numpy 数组而不是原始 Python 列表。由于不支持的列表功能,Numba 引发反射错误。您可以比较以下两个示例:

这个出现同样的错误:

TypeError: Failed in nopython mode pipeline (step: nopython mode backend)
cannot reflect element of reflected container: reflected list(reflected list(int64))

import numba

list_of_list = [[1, 2], [34, 100]]


@numba.njit()
def test(list_of_list):
    if 1 in list_of_list[0]:
        return 'haha'

test(list_of_list)

流畅运行版本是;

from numba import njit
import numpy as np


@njit
def test():
    if 1 in set(np_list_of_list[0]):
        return 'haha'


if __name__ == '__main__':
    list_of_list = [[1, 2], [34, 100]]
    np_list_of_list = np.array(list_of_list)
    print(test())