numba.core.errors.TypingError: while using np.random.randint()

numba.core.errors.TypingError: while using np.random.randint()

如何将 np.random.randint 与 numba 一起使用,因为这会引发非常大的错误,https://hastebin.com/kodixazewo.sql

from numba import jit
import numpy as np
@jit(nopython=True)
def foo():
    a = np.random.randint(16, size=(3,3))
    return a
foo()

有关 nopython 变量的更多详细信息,请参见 here

from numba import jit
import numpy as np
import warnings

warnings.filterwarnings("ignore")  # suppress NumbaWarning - remove and read for more info
@jit(nopython=False)   # I guess we need the Python interpreter to randomize with more than 2 parameters in np.random.randint()        
def foo():
    a = np.random.randint(16, size=(3,3))
    return a
foo()

您可以使用 np.ndindex 循环遍历您想要的输出大小并为每个元素单独调用 np.random.randint

确保输出数据类型足以支持 randint 调用的整数范围。

from numba import njit
import numpy as np

@njit
def foo(size=(3,3)):
    
    out = np.empty(size, dtype=np.uint16)
        
    for idx in np.ndindex(size): 
        out[idx] = np.random.randint(16)
        
    return out

这使得它适用于任意形状:

foo(size=(2,2,2))

结果:

array([[[ 8,  7],
        [15,  2]],

       [[ 4, 13],
        [ 5, 11]]], dtype=uint16)