numba 的有趣行为 - 使用 argmax() 的矢量化函数
Funny behavior with numba - guvectorized functions using argmax()
考虑以下脚本:
from numba import guvectorize, u1, i8
import numpy as np
@guvectorize([(u1[:],i8)], '(n)->()')
def f(x, res):
res = x.argmax()
x = np.array([1,2,3],dtype=np.uint8)
print(f(x))
print(x.argmax())
print(f(x))
当 运行 它时,我得到以下信息:
4382569440205035030
2
2
为什么会这样?有什么办法可以解决吗?
Python 没有引用,所以 res = ...
实际上并没有分配给输出参数,而是重新绑定名称 res
。我相信 res 指向未初始化的内存,这就是为什么你的第一个 运行 给出了一个看似随机的值。
Numba 使用切片语法 ([:]
) 来解决这个问题,它会改变 res - 您还需要将类型声明为数组。一个工作函数是:
@guvectorize([(u1[:], i8[:])], '(n)->()')
def f(x, res):
res[:] = x.argmax()
考虑以下脚本:
from numba import guvectorize, u1, i8
import numpy as np
@guvectorize([(u1[:],i8)], '(n)->()')
def f(x, res):
res = x.argmax()
x = np.array([1,2,3],dtype=np.uint8)
print(f(x))
print(x.argmax())
print(f(x))
当 运行 它时,我得到以下信息:
4382569440205035030
2
2
为什么会这样?有什么办法可以解决吗?
Python 没有引用,所以 res = ...
实际上并没有分配给输出参数,而是重新绑定名称 res
。我相信 res 指向未初始化的内存,这就是为什么你的第一个 运行 给出了一个看似随机的值。
Numba 使用切片语法 ([:]
) 来解决这个问题,它会改变 res - 您还需要将类型声明为数组。一个工作函数是:
@guvectorize([(u1[:], i8[:])], '(n)->()')
def f(x, res):
res[:] = x.argmax()