Numba 在广播时弄乱了 dtype
Numba messes up dtype when broadcasting
我想通过使用小型数据类型来安全存储。但是,当我将数字添加或乘以数组时,numba 将 dtype 更改为 int64:
纯 Numpy
在:
def f():
a=np.ones(10, dtype=np.uint8)
return a+1
f()
输出:
array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=uint8)
现在有了 numba:
在:
@njit
def f():
a=np.ones(10, dtype=np.uint8)
return a+1
f()
输出:
array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=int64)
一种解决方案是将 a+1 替换为 a+np.ones(a.shape, dtype=a.dtype) 但我无法想象更丑陋的东西。
非常感谢您的帮助!
正如您在评论中提到的,这可能是因为 numba 的默认类型是 int64
,而较小的 dtype uint8
被转换为较大的 int64
。
为什么不直接转换呢?
@njit
def f():
a=np.ones(10, dtype=np.uint8)
return (a+1).astype('uint8')
f()
输出:
array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=uint8)
这没有 a+np.ones(a.shape, dtype=a.dtype)
丑陋。 ;)
您可以使用 np.ones_like
:
@njit
def f():
a=np.ones(10, dtype=np.uint8)
return a + np.ones_like(a)
f()
输出:
array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=uint8)
...或np.full_like
:
@njit
def f():
a=np.ones(10, dtype=np.uint8)
return a + np.full_like(a, 100)
f()
输出:
array([101, 101, 101, 101, 101, 101, 101, 101, 101, 101], dtype=uint8)
如果您愿意让您的函数接受输入,您可以解决这个问题。我使用 njit
:
的 signature_or_function
参数重写了你的函数
@numba.njit(signature_or_function='uint8[:](uint8)')
def f(x):
a = np.ones(10, dtype=np.uint8)
return a+x
f(1)
# array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=uint8)
关于 numba signatures 的一些文档。如果您定义签名,numba 将为每个唯一签名编译一个专门的函数,并尝试对未明确定义签名的任何内容使用兼容的 pre-compiled 签名。那里的签名告诉它它将 return 一个数组,如果无符号 8 位整数 ('uint8[:]'
) 并接受一个无符号 8 位整数值的输入。
请注意,在这种情况下,我必须让函数接受输入,因为 numba 似乎默认将整数文字(例如,a + 1
的 1
)视为 int64
值,但是如果您指定函数的输入是 uint8
并且您没有做出更宽松的签名,那么当您编译并 运行 函数时,它将把您的输入视为uint8
而不是 up-convert 因为它不需要。
我想最简单的就是加两个 np.uint8
:
import numpy as np
from numba import njit
@njit
def f():
a=np.ones(10, dtype=np.uint8)
return a + np.uint8(1)
print(f().dtype)
输出:
uint8
我发现这比更改整个数组的类型或使用 np.ones
或 np.full
.
更优雅
我想通过使用小型数据类型来安全存储。但是,当我将数字添加或乘以数组时,numba 将 dtype 更改为 int64:
纯 Numpy
在:
def f():
a=np.ones(10, dtype=np.uint8)
return a+1
f()
输出:
array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=uint8)
现在有了 numba:
在:
@njit
def f():
a=np.ones(10, dtype=np.uint8)
return a+1
f()
输出:
array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=int64)
一种解决方案是将 a+1 替换为 a+np.ones(a.shape, dtype=a.dtype) 但我无法想象更丑陋的东西。
非常感谢您的帮助!
正如您在评论中提到的,这可能是因为 numba 的默认类型是 int64
,而较小的 dtype uint8
被转换为较大的 int64
。
为什么不直接转换呢?
@njit
def f():
a=np.ones(10, dtype=np.uint8)
return (a+1).astype('uint8')
f()
输出:
array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=uint8)
这没有 a+np.ones(a.shape, dtype=a.dtype)
丑陋。 ;)
您可以使用 np.ones_like
:
@njit
def f():
a=np.ones(10, dtype=np.uint8)
return a + np.ones_like(a)
f()
输出:
array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=uint8)
...或np.full_like
:
@njit
def f():
a=np.ones(10, dtype=np.uint8)
return a + np.full_like(a, 100)
f()
输出:
array([101, 101, 101, 101, 101, 101, 101, 101, 101, 101], dtype=uint8)
如果您愿意让您的函数接受输入,您可以解决这个问题。我使用 njit
:
signature_or_function
参数重写了你的函数
@numba.njit(signature_or_function='uint8[:](uint8)')
def f(x):
a = np.ones(10, dtype=np.uint8)
return a+x
f(1)
# array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=uint8)
关于 numba signatures 的一些文档。如果您定义签名,numba 将为每个唯一签名编译一个专门的函数,并尝试对未明确定义签名的任何内容使用兼容的 pre-compiled 签名。那里的签名告诉它它将 return 一个数组,如果无符号 8 位整数 ('uint8[:]'
) 并接受一个无符号 8 位整数值的输入。
请注意,在这种情况下,我必须让函数接受输入,因为 numba 似乎默认将整数文字(例如,a + 1
的 1
)视为 int64
值,但是如果您指定函数的输入是 uint8
并且您没有做出更宽松的签名,那么当您编译并 运行 函数时,它将把您的输入视为uint8
而不是 up-convert 因为它不需要。
我想最简单的就是加两个 np.uint8
:
import numpy as np
from numba import njit
@njit
def f():
a=np.ones(10, dtype=np.uint8)
return a + np.uint8(1)
print(f().dtype)
输出:
uint8
我发现这比更改整个数组的类型或使用 np.ones
或 np.full
.