Numba 函数无法从 numba 修饰的生成器函数附加到列表
Numba funcion failing to append to a list from a numba decorated generator function
Numba 无法在 numba 修饰的生成器中执行 __next__()
调用。错误显示 Unknown attribute '__next__' of type UniTuple(float64 x 4)
完整的错误输出是
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Unknown attribute '__next__' of type UniTuple(float64 x 4) generator(func=<function random_walk at 0x7fbe39806488>, args=(int64, int64, float64, float64, int64), has_finalizer=True)
File "rw_nb.py", line 47:
def random_walk_simulation(initial_position = 0, acceleration = 0,
<source elided>
data = []
data.append(rw.__next__())
^
[1] During: typing of get attribute at /home/igor/rw_nb.py (47)
File "rw_nb.py", line 47:
def random_walk_simulation(initial_position = 0, acceleration = 0,
<source elided>
data = []
data.append(rw.__next__())
^
MWE源码如下所示
import random
import numba
import numpy as np
@numba.njit
def random_walk(s_0, a_0, pa, pb, seed=None):
"""Initial position (often 0), acceleration, 0 < pa < pb < 1"""
if seed is not None:
random.seed(seed)
# Time, x-position, Velocity, Acceleration
t, x, v, a = 0, s_0, 0, a_0
yield (t, x, v, a)
while True:
# Roll the dices
rnd = random.random()
if rnd <= pa:
# Increase acceleration
a += .005
elif rnd <= pa+pb:
# Reduce acceleration
a -= .005
# Lets avoid too much acceleration
#lower, upper = -0.2, 0.2
a = -0.2 if a < -0.2 else 0.2 if a > 0.2 else a
# How much time has passed, since last update?
dt = random.random()
v += dt*a
x += dt*v
t += dt
yield (t, x, v, a)
@numba.njit
def random_walk_simulation(initial_position = 0, acceleration = 0,
prob_increase=5e-3, prob_decrease=5e-3,
max_distance=1e5, simul_time=1e3,
seed=None):
rw = random_walk(initial_position, acceleration,
prob_increase, prob_decrease, seed)
# Runs the first iteraction
data = []
data.append(rw.__next__())
# While there is simulation time or not too far away
while (data[-1][0] < simul_time) and (abs(data[-1][1]) < max_distance):
data.append(rw.__next__())
return np.array(data)
def main():
experiment = random_walk_simulation(seed=0)
print(experiment.shape)
if __name__ == '__main__':
main()
如果从 random_walk_simulation
函数中删除 @numba.njit
,代码将完美运行。
Numba 中的 运行 循环辅助函数 (random_walk_simulation
) 可以做什么?
要在 Numba 中调用生成器的下一项,而不是使用 data.append(rw.__next__())
,您可以这样做:
data.append(next(rw))
Numba 无法在 numba 修饰的生成器中执行 __next__()
调用。错误显示 Unknown attribute '__next__' of type UniTuple(float64 x 4)
完整的错误输出是
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Unknown attribute '__next__' of type UniTuple(float64 x 4) generator(func=<function random_walk at 0x7fbe39806488>, args=(int64, int64, float64, float64, int64), has_finalizer=True)
File "rw_nb.py", line 47:
def random_walk_simulation(initial_position = 0, acceleration = 0,
<source elided>
data = []
data.append(rw.__next__())
^
[1] During: typing of get attribute at /home/igor/rw_nb.py (47)
File "rw_nb.py", line 47:
def random_walk_simulation(initial_position = 0, acceleration = 0,
<source elided>
data = []
data.append(rw.__next__())
^
MWE源码如下所示
import random
import numba
import numpy as np
@numba.njit
def random_walk(s_0, a_0, pa, pb, seed=None):
"""Initial position (often 0), acceleration, 0 < pa < pb < 1"""
if seed is not None:
random.seed(seed)
# Time, x-position, Velocity, Acceleration
t, x, v, a = 0, s_0, 0, a_0
yield (t, x, v, a)
while True:
# Roll the dices
rnd = random.random()
if rnd <= pa:
# Increase acceleration
a += .005
elif rnd <= pa+pb:
# Reduce acceleration
a -= .005
# Lets avoid too much acceleration
#lower, upper = -0.2, 0.2
a = -0.2 if a < -0.2 else 0.2 if a > 0.2 else a
# How much time has passed, since last update?
dt = random.random()
v += dt*a
x += dt*v
t += dt
yield (t, x, v, a)
@numba.njit
def random_walk_simulation(initial_position = 0, acceleration = 0,
prob_increase=5e-3, prob_decrease=5e-3,
max_distance=1e5, simul_time=1e3,
seed=None):
rw = random_walk(initial_position, acceleration,
prob_increase, prob_decrease, seed)
# Runs the first iteraction
data = []
data.append(rw.__next__())
# While there is simulation time or not too far away
while (data[-1][0] < simul_time) and (abs(data[-1][1]) < max_distance):
data.append(rw.__next__())
return np.array(data)
def main():
experiment = random_walk_simulation(seed=0)
print(experiment.shape)
if __name__ == '__main__':
main()
如果从 random_walk_simulation
函数中删除 @numba.njit
,代码将完美运行。
Numba 中的 运行 循环辅助函数 (random_walk_simulation
) 可以做什么?
要在 Numba 中调用生成器的下一项,而不是使用 data.append(rw.__next__())
,您可以这样做:
data.append(next(rw))