Numba 命名元组签名
Numba namedtuple signature
我正在尝试为 Numba 中的 namedtuple 指定 return 类型,但我无法这样做。有人可以帮忙吗?考虑以下最小代码:
import numba as nb
from collections import namedtuple
NT = namedtuple('NT',['sum','sum2'])
@nb.njit((nb.types.NamedTuple([nb.float64,nb.float64],NT))(nb.int64,nb.float64[:,:]),fastmath=True)
def arrsum_njit(nn,xx):
arraysum = 0.0
out = NT(sum=arraysum,sum2=arraysum)
return out
我收到错误
No conversion from NT(float64 x 2) to NT(float64, float64) for 'return_value.7', defined at None
File "numbanamedtuple.py", line 10:
def arrsum_njit(nn,xx):
<source elided>
out = NT(sum=arraysum,sum2=arraysum)
return out
^
During: typing of assignment at numbanamedtuple.py (10)
File "numbanamedtuple.py", line 10:
def arrsum_njit(nn,xx):
<source elided>
out = NT(sum=arraysum,sum2=arraysum)
return out
问题是“过度优化”的 numba 编译器(错误)。向元组中添加不同类型的变量,以告诉编译器使用异构元组(内部class)。
import numba as nb
from collections import namedtuple
NT = namedtuple('NT',['sum','sum2','dummy'])
@nb.njit((nb.types.NamedTuple([nb.float64,nb.float64,nb.int64],NT))(nb.int64,nb.float64[:,:]),fastmath=True)
def arrsum_njit(nn,xx):
arraysum = 0.0
out = NT(sum=arraysum,sum2=arraysum,dummy=1)
return out
更新:
已测试:
- 数巴 0.51.2/Windows
- Numba 0.48.0/Google colab - Linux Ubuntu 18.04.5 LTS
改用NamedUniTuple
。它是同质命名元组的 numba 规范类型。
我正在尝试为 Numba 中的 namedtuple 指定 return 类型,但我无法这样做。有人可以帮忙吗?考虑以下最小代码:
import numba as nb
from collections import namedtuple
NT = namedtuple('NT',['sum','sum2'])
@nb.njit((nb.types.NamedTuple([nb.float64,nb.float64],NT))(nb.int64,nb.float64[:,:]),fastmath=True)
def arrsum_njit(nn,xx):
arraysum = 0.0
out = NT(sum=arraysum,sum2=arraysum)
return out
我收到错误
No conversion from NT(float64 x 2) to NT(float64, float64) for 'return_value.7', defined at None
File "numbanamedtuple.py", line 10:
def arrsum_njit(nn,xx):
<source elided>
out = NT(sum=arraysum,sum2=arraysum)
return out
^
During: typing of assignment at numbanamedtuple.py (10)
File "numbanamedtuple.py", line 10:
def arrsum_njit(nn,xx):
<source elided>
out = NT(sum=arraysum,sum2=arraysum)
return out
问题是“过度优化”的 numba 编译器(错误)。向元组中添加不同类型的变量,以告诉编译器使用异构元组(内部class)。
import numba as nb
from collections import namedtuple
NT = namedtuple('NT',['sum','sum2','dummy'])
@nb.njit((nb.types.NamedTuple([nb.float64,nb.float64,nb.int64],NT))(nb.int64,nb.float64[:,:]),fastmath=True)
def arrsum_njit(nn,xx):
arraysum = 0.0
out = NT(sum=arraysum,sum2=arraysum,dummy=1)
return out
更新: 已测试:
- 数巴 0.51.2/Windows
- Numba 0.48.0/Google colab - Linux Ubuntu 18.04.5 LTS
改用NamedUniTuple
。它是同质命名元组的 numba 规范类型。