numba 函数不会编译
numba function will not compile
所以我在 numba 中有一个函数由于某种原因没有编译(计算 ROC 曲线下的面积)。我也不确定我将如何调试该函数,因为当我在 numba 装饰函数中设置断点时,我的调试器不工作。
函数如下:
@nb.njit()
def auc_numba(fcst, obs):
L = obs.size
i_ord = fcst.argsort()
sumV = 0.
sumV2 = 0.
sumW = 0.
sumW2 = 0.
n = 0
m = 0
i = 0
while True:
nn = mm = 0
while True:
j = i_ord[i]
if obs[j]:
mm += 1
else:
nn += 1
if i == L - 1:
break
jp1 = i_ord[i + 1]
if fcst[j] != fcst[jp1]:
break
i += 1
sumW += nn * (m + mm / 2.0)
sumW2 += nn * (m + mm / 2.0) * (m + mm / 2.0)
sumV += mm * (n + nn / 2.0)
sumV2 += mm * (n + nn / 2.0) * (n + nn / 2.0)
n += nn
m += mm
i += 1
if i >= L:
break
theta = sumV / (m * n)
v = sumV2 / ((m - 1) * n * n) - sumV * sumV / (m * (m - 1) * n * n)
w = sumW2 / ((n - 1) * m * m) - sumW * sumW / (n * (n - 1) * m * m)
sd_auc = np.sqrt(v / m + w / n)
return np.array([theta, sd_auc])
我的想法是我实现的 while 循环有问题。类型有可能是错误的,因此 break 没有被激活,函数永远 运行ning。
下面是一些要测试的示例数据:
obs = np.array([1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1])
fcst = np.array([0.7083333, 0.5416667, 0.875, 0.5833333, 0.2083333, 0.8333333, 0.1666667, 0.9583333, 0.625, 0.1666667, 0.5, 1.0, 0.6666667, 0.2083333, 0.875, 0.75, 0.625, 0.3333333, 0.8333333, 0.2083333, 0.125, 0.0, 0.875, 0.8333333, 0.125, 0.5416667, 0.75])
当我 运行 在没有装饰器的情况下我得到 [0.89488636 0.06561209] 这是正确的值。
所以我想我是否能得到一些帮助来理解为什么它没有编译,也许还有一些关于如何在 numba 中调试的提示?
双 while True
循环发生了一些奇怪的事情。无论出于何种原因(我不明白),如果您在顶部创建两个变量 x
和 y
然后:
x = 1
y = 0
while True:
nn = mm = 0
while x > y:
并保持其他一切不变,代码有效。我将向 Numba 跟踪器提交一个问题,因为这对我来说似乎是一个错误。
更新: numba问题可以找到here
所以我在 numba 中有一个函数由于某种原因没有编译(计算 ROC 曲线下的面积)。我也不确定我将如何调试该函数,因为当我在 numba 装饰函数中设置断点时,我的调试器不工作。
函数如下:
@nb.njit()
def auc_numba(fcst, obs):
L = obs.size
i_ord = fcst.argsort()
sumV = 0.
sumV2 = 0.
sumW = 0.
sumW2 = 0.
n = 0
m = 0
i = 0
while True:
nn = mm = 0
while True:
j = i_ord[i]
if obs[j]:
mm += 1
else:
nn += 1
if i == L - 1:
break
jp1 = i_ord[i + 1]
if fcst[j] != fcst[jp1]:
break
i += 1
sumW += nn * (m + mm / 2.0)
sumW2 += nn * (m + mm / 2.0) * (m + mm / 2.0)
sumV += mm * (n + nn / 2.0)
sumV2 += mm * (n + nn / 2.0) * (n + nn / 2.0)
n += nn
m += mm
i += 1
if i >= L:
break
theta = sumV / (m * n)
v = sumV2 / ((m - 1) * n * n) - sumV * sumV / (m * (m - 1) * n * n)
w = sumW2 / ((n - 1) * m * m) - sumW * sumW / (n * (n - 1) * m * m)
sd_auc = np.sqrt(v / m + w / n)
return np.array([theta, sd_auc])
我的想法是我实现的 while 循环有问题。类型有可能是错误的,因此 break 没有被激活,函数永远 运行ning。
下面是一些要测试的示例数据:
obs = np.array([1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1])
fcst = np.array([0.7083333, 0.5416667, 0.875, 0.5833333, 0.2083333, 0.8333333, 0.1666667, 0.9583333, 0.625, 0.1666667, 0.5, 1.0, 0.6666667, 0.2083333, 0.875, 0.75, 0.625, 0.3333333, 0.8333333, 0.2083333, 0.125, 0.0, 0.875, 0.8333333, 0.125, 0.5416667, 0.75])
当我 运行 在没有装饰器的情况下我得到 [0.89488636 0.06561209] 这是正确的值。
所以我想我是否能得到一些帮助来理解为什么它没有编译,也许还有一些关于如何在 numba 中调试的提示?
双 while True
循环发生了一些奇怪的事情。无论出于何种原因(我不明白),如果您在顶部创建两个变量 x
和 y
然后:
x = 1
y = 0
while True:
nn = mm = 0
while x > y:
并保持其他一切不变,代码有效。我将向 Numba 跟踪器提交一个问题,因为这对我来说似乎是一个错误。
更新: numba问题可以找到here