Numba 的打字错误
TypingError for Numba
我有这一段代码,使用Numba来加速处理。基本上,particle_dtype 被定义为使用 Numba 生成代码 运行。但是,会报告 TypingError,说“无法确定 的 Numba 类型”。我不知道问题出在哪里。
import numpy
from numba import njit
particle_dtype = numpy.dtype({'names':['x','y','z','m','phi'],
'formats':[numpy.double,
numpy.double,
numpy.double,
numpy.double,
numpy.double]})
def create_n_random_particles(n, m, domain=1):
parts = numpy.zeros((n), dtype=particle_dtype)
parts['x'] = numpy.random.random(size=n) * domain
parts['y'] = numpy.random.random(size=n) * domain
parts['z'] = numpy.random.random(size=n) * domain
parts['m'] = m
parts['phi'] = 0.0
return parts
def distance(se, other):
return numpy.sqrt(numpy.square(se['x'] - other['x']) +
numpy.square(se['y'] - other['y']) +
numpy.square(se['z'] - other['z']))
parts = create_n_random_particles(10, .001, 1)
@njit
def direct_sum(particles):
for i, target in enumerate(particles):
for j in range(particles.shape[0]):
if i == j:
continue
source = particles[j]
r = distance(target, source)
# target['phi'] += source['m'] / r
target['phi'] = target['phi'] + source['m'] / r
return(target['phi'])
print(direct_sum(parts) )
我猜是因为某处使用了不受支持的函数或操作,但我找不到它。感谢您的帮助。
direct_sum
是一个 JITed 函数,不能调用 distance
因为它不是 JITed(pure-Python 函数)。您也需要在 distance
上使用装饰器 @njit
。
我有这一段代码,使用Numba来加速处理。基本上,particle_dtype 被定义为使用 Numba 生成代码 运行。但是,会报告 TypingError,说“无法确定
import numpy
from numba import njit
particle_dtype = numpy.dtype({'names':['x','y','z','m','phi'],
'formats':[numpy.double,
numpy.double,
numpy.double,
numpy.double,
numpy.double]})
def create_n_random_particles(n, m, domain=1):
parts = numpy.zeros((n), dtype=particle_dtype)
parts['x'] = numpy.random.random(size=n) * domain
parts['y'] = numpy.random.random(size=n) * domain
parts['z'] = numpy.random.random(size=n) * domain
parts['m'] = m
parts['phi'] = 0.0
return parts
def distance(se, other):
return numpy.sqrt(numpy.square(se['x'] - other['x']) +
numpy.square(se['y'] - other['y']) +
numpy.square(se['z'] - other['z']))
parts = create_n_random_particles(10, .001, 1)
@njit
def direct_sum(particles):
for i, target in enumerate(particles):
for j in range(particles.shape[0]):
if i == j:
continue
source = particles[j]
r = distance(target, source)
# target['phi'] += source['m'] / r
target['phi'] = target['phi'] + source['m'] / r
return(target['phi'])
print(direct_sum(parts) )
我猜是因为某处使用了不受支持的函数或操作,但我找不到它。感谢您的帮助。
direct_sum
是一个 JITed 函数,不能调用 distance
因为它不是 JITed(pure-Python 函数)。您也需要在 distance
上使用装饰器 @njit
。