Python 上 Numpy 与 Numba 的均值和标准差

Numpy's mean and standard deviation with Numba on Python

我正在尝试在一个函数中使用 Numpy 的均值和标准差函数,它们似乎与 Numba 不兼容,尽管 Numba documentation states them as compatible.

我的代码如下:

import numpy as np
import numba


a = [1, 2, 3, 4, 5, 6]

# @numba.jit(nopython=True, parallel=True)
def nmeanstd(a, n):
    b = []; c = []
    for i in range(n):
        b.append(np.mean(a))
        c.append(np.std(a))
    
    return b, c

mean, std = nmeanstd(a, 10)

查看 meanstd 时的输出是预期的:

mean
Out[31]: [3.5, 3.5, 3.5, 3.5, 3.5, 3.5, 3.5, 3.5, 3.5, 3.5]

std
Out[32]: 
[1.707825127659933,
 1.707825127659933,
 1.707825127659933,
 1.707825127659933,
 1.707825127659933,
 1.707825127659933,
 1.707825127659933,
 1.707825127659933,
 1.707825127659933,
 1.707825127659933]

但我不知道为什么,当我取消注释 @numba.jit 函数时,会出现以下消息:

TypingError: No implementation of function Function(<function mean at 0x11a0e6e50>) found for signature:
 
mean(reflected list(int64)<iv=None>)
 
There are 2 candidate implementations:
      - Of which 2 did not match due to:
      Overload of function 'mean': File: numba/core/typing/npydecl.py: Line 378.
        With argument(s): '(reflected list(int64)<iv=None>)':
       No match.

During: resolving callee type: Function(<function mean at 0x11a0e6e50>)

如果我评论计算平均值的行,std 也是如此。怎么了?我虽然他们会 运行 numba 正确。您知道使用 Numba 计算均值和标准差的任何方法吗?

错误消息表明 Numba 不知道如何计算 listmean。如果首先将输入列表转换为 numpy 数组,则您的代码可以正常工作(使用 @jit):

mean, std = nmeanstd(np.array(a), 10)

文档显示 NumPy 数组直接与 Numba 一起工作并且非常高效。

如果您将 a 转换为 NumPy 数组,则代码有效。

mean, std = nmeanstd(np.array(a), 10)