如何在 numba 中创建结构化标量列表?

How to create a list of structured scalars in numba?

import numpy as np
from numba import njit

dt = np.dtype([('x', np.float64), ('y', np.float64)])

@njit
def f():
#    a = np.zeros(2, dtype=dt)         # this works
#    return a['x']
    b = np.array((0.5, 1.5), dtype=dt)    # this doesn't
#    return b['x']
f()

错误信息是

NotImplementedError: Cannot cast float64 to Record(x[type=float64;offset=0],y[type=float64;offset=8];16;False): %".69" = phi double [%".70", %"switch.0"], [%".72", %"switch.1"]

没有 @jit 它工作正常。

我真正想要实现的是创建自定义 dtype 标量列表。我尝试了以下替代方法:

更新: 到目前为止我能得到的最远点是:

dt = np.dtype([('x', np.float64), ('y', np.float64)])
@nb.njit
def f():
    a = np.array((0.5, 1.5))
    b = a.view(dt)
    return b.x
f()

array([0.5])

但它不是标量,它是一个大小为 1 的数组(带或不带 @jit)。

更新2:

Recfunctions 还没有包含在 numba 中。

from numpy.lib import recfunctions
from numba import njit
dt = np.dtype([('x', np.float64), ('y', np.float64)])
@njit
def f():
    a = np.array((1,2))
    b = recfunctions.unstructured_to_structured(a, dt)
    return b['x']
f()

Unknown attribute 'unstructured_to_structured' of type 
Module(<module 'numpy.lib.recfunctions'

显然 numba 还没有完全实现 numpy 结构化数组的功能。该错误表明它在将值从元组分配给定义的数组时遇到问题。

玩了一会儿后,我发现这行得通:

In [399]: dt = np.dtype([('x', np.float64),('y', np.float64)])                  
In [400]: @numba.njit 
     ...: def nf(vals, dt): 
     ...:     b = np.zeros((), dtype=dt) 
     ...:     b['x'][...] = vals[0] 
     ...:     b['y'][...] = vals[1] 
     ...:     return b 
     ...:                                                                       
In [401]: nf((.5,1.5),dt)                                                       
Out[401]: array((0.5, 1.5), dtype=[('x', '<f8'), ('y', '<f8')])

或者制作一维数组:

In [405]: @numba.njit 
     ...: def nf1(n, x, y , dt): 
     ...:     b = np.zeros(n, dtype=dt) 
     ...:     b['x'][...] = x 
     ...:     b['y'][...] = y 
     ...:     return b 
     ...:                                                                       
In [406]: nf1(3, np.arange(3), np.ones(3), dt)                                  
Out[406]: array([(0., 1.), (1., 1.), (2., 1.)], dtype=[('x', '<f8'), ('y', '<f8')])