Numba 的打字错误

TypingError for Numba

我有这一段代码,使用Numba来加速处理。基本上,particle_dtype 被定义为使用 Numba 生成代码 运行。但是,会报告 TypingError,说“无法确定 的 Numba 类型”。我不知道问题出在哪里。

import numpy
from numba import njit

particle_dtype = numpy.dtype({'names':['x','y','z','m','phi'], 
                             'formats':[numpy.double, 
                                        numpy.double, 
                                        numpy.double, 
                                        numpy.double, 
                                        numpy.double]}) 


def create_n_random_particles(n, m, domain=1):
    parts = numpy.zeros((n), dtype=particle_dtype)
    parts['x'] = numpy.random.random(size=n) * domain
    parts['y'] = numpy.random.random(size=n) * domain
    parts['z'] = numpy.random.random(size=n) * domain
    parts['m'] = m
    parts['phi'] = 0.0

    return parts


def distance(se, other):
    return numpy.sqrt(numpy.square(se['x'] - other['x']) + 
                      numpy.square(se['y'] - other['y']) + 
                      numpy.square(se['z'] - other['z']))


parts = create_n_random_particles(10, .001, 1)


@njit
def direct_sum(particles):
    for i, target in enumerate(particles):
        for j in range(particles.shape[0]):
            if i == j:
                continue
            source = particles[j]
            r = distance(target, source)
            # target['phi'] += source['m'] / r
            target['phi'] = target['phi'] + source['m'] / r
            return(target['phi'])
            
print(direct_sum(parts) ) 

我猜是因为某处使用了不受支持的函数或操作,但我找不到它。感谢您的帮助。

direct_sum 是一个 JITed 函数,不能调用 distance 因为它不是 JITed(pure-Python 函数)。您也需要在 distance 上使用装饰器 @njit