Numba 在广播时弄乱了 dtype

Numba messes up dtype when broadcasting

我想通过使用小型数据类型来安全存储。但是,当我将数字添加或乘以数组时,numba 将 dtype 更改为 int64:

纯 Numpy

在:

def f():
    a=np.ones(10, dtype=np.uint8)
    return a+1
f()

输出:

array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=uint8)

现在有了 numba:

在:

@njit
def f():
    a=np.ones(10, dtype=np.uint8)
    return a+1
f()

输出:

array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=int64)

一种解决方案是将 a+1 替换为 a+np.ones(a.shape, dtype=a.dtype) 但我无法想象更丑陋的东西。

非常感谢您的帮助!

正如您在评论中提到的,这可能是因为 numba 的默认类型是 int64,而较小的 dtype uint8 被转换为较大的 int64

为什么不直接转换呢?

@njit
def f():
    a=np.ones(10, dtype=np.uint8)
    return (a+1).astype('uint8')
f()

输出:

array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=uint8)

这没有 a+np.ones(a.shape, dtype=a.dtype) 丑陋。 ;)

您可以使用 np.ones_like:

@njit
def f():
    a=np.ones(10, dtype=np.uint8)
    return a + np.ones_like(a)
f()

输出:

array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=uint8)

...或np.full_like:

@njit
def f():
    a=np.ones(10, dtype=np.uint8)
    return a + np.full_like(a, 100)
f()

输出:

array([101, 101, 101, 101, 101, 101, 101, 101, 101, 101], dtype=uint8)

如果您愿意让您的函数接受输入,您可以解决这个问题。我使用 njit:

signature_or_function 参数重写了你的函数
@numba.njit(signature_or_function='uint8[:](uint8)')
def f(x):
    a = np.ones(10, dtype=np.uint8)
    return a+x

f(1)
# array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=uint8)

关于 numba signatures 的一些文档。如果您定义签名,numba 将为每个唯一签名编译一个专门的函数,并尝试对未明确定义签名的任何内容使用兼容的 pre-compiled 签名。那里的签名告诉它它将 return 一个数组,如果无符号 8 位整数 ('uint8[:]') 并接受一个无符号 8 位整数值的输入。

请注意,在这种情况下,我必须让函数接受输入,因为 numba 似乎默认将整数文字(例如,a + 11)视为 int64 值,但是如果您指定函数的输入是 uint8 并且您没有做出更宽松的签名,那么当您编译并 运行 函数时,它将把您的输入视为uint8 而不是 up-convert 因为它不需要。

我想最简单的就是加两个 np.uint8:

import numpy as np
from numba import njit

@njit
def f():
    a=np.ones(10, dtype=np.uint8)
    return a + np.uint8(1)
print(f().dtype)

输出:

uint8

我发现这比更改整个数组的类型或使用 np.onesnp.full.

更优雅