Numba:不支持 cell 变量(闭包变量)
Numba : cell vars are not supported
我想使用 numba 来加速这个函数:
from numba import jit
@jit
def rownowaga_numba(u, v):
    """Fill nine per-direction tables from the 2-D fields ``u`` and ``v``.

    ``u`` and ``v`` are indexed as ``u[i][j]`` / ``v[i][j]``.  The result is a
    nested list ``f`` with ``f[k][i][j]`` holding the value for direction
    ``k`` (0..8) at cell ``(i, j)``.

    The column count is read from row index 1 of ``u``.

    NOTE(review): this is the snippet that produces the quoted
    "cell vars are not supported" error; presumably the nested comprehension
    below (which reads ``n_rows`` / ``n_cols`` from the enclosing function)
    is what triggers it, so it is left in place on purpose.
    """
    n_rows = len(u)
    n_cols = len(u[1])
    table = [[[0 for _ in range(n_cols)] for _ in range(n_rows)] for _ in range(9)]

    # Direction components and per-direction weights (nine entries each).
    dir_x = [0., 1., 0., -1., 0., 1., -1., -1., 1.]
    dir_y = [0., 0., 1., 0., -1., 1., 1., -1., -1.]
    weight = [4./9, 1./9, 1./9, 1./9, 1./9, 1./36, 1./36, 1./36, 1./36]

    for row in range(n_rows):
        for col in range(n_cols):
            for k in range(9):
                ux = u[row][col]
                uy = v[row][col]
                speed_sq = (ux**2 + uy**2)
                proj = ux*dir_x[k] + uy*dir_y[k]
                table[k][row][col] = weight[k] + weight[k]*(3.0*proj + 4.5*proj**2 - 1.5*speed_sq)
    return table
我用如下数据来测试它:
import timeit
import math as m
# Two identically built 1000 x 40 nested lists used as the two inputs.
u = [
    [m.sin(i) + m.cos(j) for j in range(40)]
    for i in range(1000)
]
y = [
    [m.sin(i) + m.cos(j) for j in range(40)]
    for i in range(1000)
]

# Time ten back-to-back calls and report the total elapsed time.
# NOTE(review): rownowaga_pypy is not defined in the snippet shown here (the
# function above is named rownowaga_numba); the name matches the traceback,
# so it is presumably what the real script used -- confirm.
t0 = timeit.default_timer()
for i in range(10):
    f = rownowaga_pypy(u, y)
dt = timeit.default_timer() - t0
print('loop time:', dt)
我收到这个错误:
Traceback (most recent call last):
File "C:\Users\Ricevind\Desktop\PyPy\Skrypty\Rownowaga.py", line 29, in <module>
f = rownowaga_pypy(u,y)
File "C:\pyzo2014a\lib\site-packages\numba\dispatcher.py", line 171, in _compile_for_args
return self.compile(sig)
File "C:\pyzo2014a\lib\site-packages\numba\dispatcher.py", line 348, in compile
flags=flags, locals=self.locals)
File "C:\pyzo2014a\lib\site-packages\numba\compiler.py", line 637, in compile_extra
return pipeline.compile_extra(func)
File "C:\pyzo2014a\lib\site-packages\numba\compiler.py", line 356, in compile_extra
raise e
File "C:\pyzo2014a\lib\site-packages\numba\compiler.py", line 351, in compile_extra
bc = self.extract_bytecode(func)
File "C:\pyzo2014a\lib\site-packages\numba\compiler.py", line 343, in extract_bytecode
bc = bytecode.ByteCode(func=self.func)
File "C:\pyzo2014a\lib\site-packages\numba\bytecode.py", line 343, in __init__
raise NotImplementedError("cell vars are not supported")
NotImplementedError: cell vars are not supported
我最感兴趣的是 "cell vars are not supported" 的含义,因为在 Google 上搜索不到任何结果。
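Numba 目前对嵌套列表(列表的列表)的支持还不太好(至少截至 v0.21 是如此)。我认为这就是 'cell vars' 错误所指的内容,但我不能 100% 确定。(补充说明:"cell var" 指的是被内层作用域引用的闭包变量;在当时的 Python 3 中,列表推导式拥有独立的作用域,因此嵌套推导式中引用的 wymiar_x 和 wymiar_y 会成为 cell 变量。)下面,我把所有数据都转换为 numpy 数组,使代码能够被 numba 优化: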
import numpy as np
import numba as nb
import math
def rownowaga(u, v):
    """Return nine per-direction tables computed from the fields ``u`` and ``v``.

    Args:
        u: Rectangular nested sequence, indexed as ``u[i][j]``.
        v: Nested sequence with the same layout as ``u``.

    Returns:
        A nested list ``f`` where ``f[k][i][j]`` is
        ``w[k] + w[k] * (3*cu + 4.5*cu**2 - 1.5*(up**2 + vp**2))`` with
        ``up = u[i][j]``, ``vp = v[i][j]`` and ``cu = up*cx[k] + vp*cy[k]``.
        For empty ``u`` the result is nine empty lists.
    """
    wymiar_x = len(u)
    # Read the width from the first row: the previous u[1] lookup raised
    # IndexError for single-row (and empty) input.
    wymiar_y = len(u[0]) if wymiar_x else 0

    # Direction components and per-direction weights (nine entries each).
    cx = [0., 1., 0., -1., 0., 1., -1., -1., 1.]
    cy = [0., 0., 1., 0., -1., 1., 1., -1., -1.]
    w = [4./9, 1./9, 1./9, 1./9, 1./9, 1./36, 1./36, 1./36, 1./36]

    f = [[[0.0] * wymiar_y for _ in range(wymiar_x)] for _ in range(9)]

    for i in range(wymiar_x):
        for j in range(wymiar_y):
            # These depend only on the cell, not on the direction k.
            up = u[i][j]
            vp = v[i][j]
            udot = up**2 + vp**2
            for k in range(9):
                cu = up*cx[k] + vp*cy[k]
                f[k][i][j] = w[k] + w[k]*(3.0*cu + 4.5*cu**2 - 1.5*udot)
    return f
# Module-level lookup tables, kept outside the jitted function so that numba
# treats them as constant arrays.
# NOTE(review): the nine directions and the 4/9, 1/9, 1/36 weights look like
# the D2Q9 lattice set -- confirm.
cx = np.array([0.0, 1.0, 0.0, -1.0, 0.0, 1.0, -1.0, -1.0, 1.0])
cy = np.array([0.0, 0.0, 1.0, 0.0, -1.0, 1.0, 1.0, -1.0, -1.0])
w = np.array([4.0 / 9.0] + [1.0 / 9.0] * 4 + [1.0 / 36.0] * 4)
@nb.jit(nopython=True)
def rownowaga_numba(u, v):
    """Numba-compiled version of the table computation, for 2-D numpy arrays.

    Args:
        u: 2-D numpy array, indexed as ``u[i, j]``.
        v: 2-D numpy array with the same shape as ``u``.

    Returns:
        A numpy array ``f`` of shape ``(9, u.shape[0], u.shape[1])`` where
        ``f[k, i, j]`` is
        ``w[k] + w[k] * (3*cu + 4.5*cu**2 - 1.5*(up*up + vp*vp))`` with
        ``up = u[i, j]``, ``vp = v[i, j]`` and ``cu = up*cx[k] + vp*cy[k]``.

    Reads the module-level arrays ``cx``, ``cy`` and ``w``.
    """
    wymiar_x = u.shape[0]
    # Take the width from the array shape instead of indexing row 1, so a
    # single-row array does not depend on a row that is not there.
    wymiar_y = u.shape[1]
    f = np.zeros((9, wymiar_x, wymiar_y))
    # range (not xrange): xrange does not exist on Python 3, and numba
    # compiles range loops on both Python 2 and 3.
    for i in range(wymiar_x):
        for j in range(wymiar_y):
            # These depend only on the cell, not on the direction k.
            up = u[i, j]
            vp = v[i, j]
            udot = up*up + vp*vp
            for k in range(9):
                cu = up*cx[k] + vp*cy[k]
                f[k, i, j] = w[k] + w[k]*(3.0*cu + 4.5*cu**2 - 1.5*udot)
    return f
现在让我们设置一些测试数组:
# Nested-list inputs: 1000 rows by 40 columns, entry (r, c) = sin(r) + cos(c).
u = [
    [math.sin(r) + math.cos(c) for c in range(40)]
    for r in range(1000)
]
y = [
    [math.sin(r) + math.cos(c) for c in range(40)]
    for r in range(1000)
]

# numpy copies of the same data, for the numba-compiled function.
u_np = np.array(u)
y_np = np.array(y)
首先让我们验证我的 numba 代码给出的答案是否与 OP 的代码相同:
# Reference result from the pure-Python implementation, fed the nested lists.
f1 = rownowaga(u, y)
# Result from the numba-compiled implementation, fed the numpy copies.
f2 = rownowaga_numba(u_np, y_np)
来自 ipython 笔记本:
In [13]: np.allclose(f2, np.array(f1))
Out[13]:
True
现在让我们在笔记本电脑上计时:
In [15]: %timeit f1 = rownowaga(u, y)
1 loops, best of 3: 288 ms per loop
In [16]: %timeit f2 = rownowaga_numba(u_np, y_np)
1000 loops, best of 3: 973 µs per loop
因此,我们只做了极少的代码改动,就获得了约 300 倍的加速。请注意,我使用的是 0.22 发布之前的 Numba 每夜构建版(nightly build):
In [16]: nb.__version__
Out[16]:
'0.21.0+137.gac9929d'
我想使用 numba 来加速这个函数:
from numba import jit
@jit
def rownowaga_numba(u, v):
    """Fill nine per-direction tables from the 2-D fields ``u`` and ``v``.

    ``u`` and ``v`` are indexed as ``u[i][j]`` / ``v[i][j]``.  The result is a
    nested list ``f`` with ``f[k][i][j]`` holding the value for direction
    ``k`` (0..8) at cell ``(i, j)``.

    The column count is read from row index 1 of ``u``.

    NOTE(review): this is the snippet that produces the quoted
    "cell vars are not supported" error; presumably the nested comprehension
    below (which reads ``n_rows`` / ``n_cols`` from the enclosing function)
    is what triggers it, so it is left in place on purpose.
    """
    n_rows = len(u)
    n_cols = len(u[1])
    table = [[[0 for _ in range(n_cols)] for _ in range(n_rows)] for _ in range(9)]

    # Direction components and per-direction weights (nine entries each).
    dir_x = [0., 1., 0., -1., 0., 1., -1., -1., 1.]
    dir_y = [0., 0., 1., 0., -1., 1., 1., -1., -1.]
    weight = [4./9, 1./9, 1./9, 1./9, 1./9, 1./36, 1./36, 1./36, 1./36]

    for row in range(n_rows):
        for col in range(n_cols):
            for k in range(9):
                ux = u[row][col]
                uy = v[row][col]
                speed_sq = (ux**2 + uy**2)
                proj = ux*dir_x[k] + uy*dir_y[k]
                table[k][row][col] = weight[k] + weight[k]*(3.0*proj + 4.5*proj**2 - 1.5*speed_sq)
    return table
我用如下数据来测试它:
import timeit
import math as m
# Two identically built 1000 x 40 nested lists used as the two inputs.
u = [
    [m.sin(i) + m.cos(j) for j in range(40)]
    for i in range(1000)
]
y = [
    [m.sin(i) + m.cos(j) for j in range(40)]
    for i in range(1000)
]

# Time ten back-to-back calls and report the total elapsed time.
# NOTE(review): rownowaga_pypy is not defined in the snippet shown here (the
# function above is named rownowaga_numba); the name matches the traceback,
# so it is presumably what the real script used -- confirm.
t0 = timeit.default_timer()
for i in range(10):
    f = rownowaga_pypy(u, y)
dt = timeit.default_timer() - t0
print('loop time:', dt)
我收到这个错误:
Traceback (most recent call last):
File "C:\Users\Ricevind\Desktop\PyPy\Skrypty\Rownowaga.py", line 29, in <module>
f = rownowaga_pypy(u,y)
File "C:\pyzo2014a\lib\site-packages\numba\dispatcher.py", line 171, in _compile_for_args
return self.compile(sig)
File "C:\pyzo2014a\lib\site-packages\numba\dispatcher.py", line 348, in compile
flags=flags, locals=self.locals)
File "C:\pyzo2014a\lib\site-packages\numba\compiler.py", line 637, in compile_extra
return pipeline.compile_extra(func)
File "C:\pyzo2014a\lib\site-packages\numba\compiler.py", line 356, in compile_extra
raise e
File "C:\pyzo2014a\lib\site-packages\numba\compiler.py", line 351, in compile_extra
bc = self.extract_bytecode(func)
File "C:\pyzo2014a\lib\site-packages\numba\compiler.py", line 343, in extract_bytecode
bc = bytecode.ByteCode(func=self.func)
File "C:\pyzo2014a\lib\site-packages\numba\bytecode.py", line 343, in __init__
raise NotImplementedError("cell vars are not supported")
NotImplementedError: cell vars are not supported
我最感兴趣的是 "cell vars are not supported" 的含义,因为在 Google 上搜索不到任何结果。
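Numba 目前对嵌套列表(列表的列表)的支持还不太好(至少截至 v0.21 是如此)。我认为这就是 'cell vars' 错误所指的内容,但我不能 100% 确定。(补充说明:"cell var" 指的是被内层作用域引用的闭包变量;在当时的 Python 3 中,列表推导式拥有独立的作用域,因此嵌套推导式中引用的 wymiar_x 和 wymiar_y 会成为 cell 变量。)下面,我把所有数据都转换为 numpy 数组,使代码能够被 numba 优化: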
import numpy as np
import numba as nb
import math
def rownowaga(u, v):
    """Return nine per-direction tables computed from the fields ``u`` and ``v``.

    Args:
        u: Rectangular nested sequence, indexed as ``u[i][j]``.
        v: Nested sequence with the same layout as ``u``.

    Returns:
        A nested list ``f`` where ``f[k][i][j]`` is
        ``w[k] + w[k] * (3*cu + 4.5*cu**2 - 1.5*(up**2 + vp**2))`` with
        ``up = u[i][j]``, ``vp = v[i][j]`` and ``cu = up*cx[k] + vp*cy[k]``.
        For empty ``u`` the result is nine empty lists.
    """
    wymiar_x = len(u)
    # Read the width from the first row: the previous u[1] lookup raised
    # IndexError for single-row (and empty) input.
    wymiar_y = len(u[0]) if wymiar_x else 0

    # Direction components and per-direction weights (nine entries each).
    cx = [0., 1., 0., -1., 0., 1., -1., -1., 1.]
    cy = [0., 0., 1., 0., -1., 1., 1., -1., -1.]
    w = [4./9, 1./9, 1./9, 1./9, 1./9, 1./36, 1./36, 1./36, 1./36]

    f = [[[0.0] * wymiar_y for _ in range(wymiar_x)] for _ in range(9)]

    for i in range(wymiar_x):
        for j in range(wymiar_y):
            # These depend only on the cell, not on the direction k.
            up = u[i][j]
            vp = v[i][j]
            udot = up**2 + vp**2
            for k in range(9):
                cu = up*cx[k] + vp*cy[k]
                f[k][i][j] = w[k] + w[k]*(3.0*cu + 4.5*cu**2 - 1.5*udot)
    return f
# Module-level lookup tables, kept outside the jitted function so that numba
# treats them as constant arrays.
# NOTE(review): the nine directions and the 4/9, 1/9, 1/36 weights look like
# the D2Q9 lattice set -- confirm.
cx = np.array([0.0, 1.0, 0.0, -1.0, 0.0, 1.0, -1.0, -1.0, 1.0])
cy = np.array([0.0, 0.0, 1.0, 0.0, -1.0, 1.0, 1.0, -1.0, -1.0])
w = np.array([4.0 / 9.0] + [1.0 / 9.0] * 4 + [1.0 / 36.0] * 4)
@nb.jit(nopython=True)
def rownowaga_numba(u, v):
    """Numba-compiled version of the table computation, for 2-D numpy arrays.

    Args:
        u: 2-D numpy array, indexed as ``u[i, j]``.
        v: 2-D numpy array with the same shape as ``u``.

    Returns:
        A numpy array ``f`` of shape ``(9, u.shape[0], u.shape[1])`` where
        ``f[k, i, j]`` is
        ``w[k] + w[k] * (3*cu + 4.5*cu**2 - 1.5*(up*up + vp*vp))`` with
        ``up = u[i, j]``, ``vp = v[i, j]`` and ``cu = up*cx[k] + vp*cy[k]``.

    Reads the module-level arrays ``cx``, ``cy`` and ``w``.
    """
    wymiar_x = u.shape[0]
    # Take the width from the array shape instead of indexing row 1, so a
    # single-row array does not depend on a row that is not there.
    wymiar_y = u.shape[1]
    f = np.zeros((9, wymiar_x, wymiar_y))
    # range (not xrange): xrange does not exist on Python 3, and numba
    # compiles range loops on both Python 2 and 3.
    for i in range(wymiar_x):
        for j in range(wymiar_y):
            # These depend only on the cell, not on the direction k.
            up = u[i, j]
            vp = v[i, j]
            udot = up*up + vp*vp
            for k in range(9):
                cu = up*cx[k] + vp*cy[k]
                f[k, i, j] = w[k] + w[k]*(3.0*cu + 4.5*cu**2 - 1.5*udot)
    return f
现在让我们设置一些测试数组:
# Nested-list inputs: 1000 rows by 40 columns, entry (r, c) = sin(r) + cos(c).
u = [
    [math.sin(r) + math.cos(c) for c in range(40)]
    for r in range(1000)
]
y = [
    [math.sin(r) + math.cos(c) for c in range(40)]
    for r in range(1000)
]

# numpy copies of the same data, for the numba-compiled function.
u_np = np.array(u)
y_np = np.array(y)
首先让我们验证我的 numba 代码给出的答案是否与 OP 的代码相同:
# Reference result from the pure-Python implementation, fed the nested lists.
f1 = rownowaga(u, y)
# Result from the numba-compiled implementation, fed the numpy copies.
f2 = rownowaga_numba(u_np, y_np)
来自 ipython 笔记本:
In [13]: np.allclose(f2, np.array(f1))
Out[13]:
True
现在让我们在笔记本电脑上计时:
In [15]: %timeit f1 = rownowaga(u, y)
1 loops, best of 3: 288 ms per loop
In [16]: %timeit f2 = rownowaga_numba(u_np, y_np)
1000 loops, best of 3: 973 µs per loop
因此,我们只做了极少的代码改动,就获得了约 300 倍的加速。请注意,我使用的是 0.22 发布之前的 Numba 每夜构建版(nightly build):
In [16]: nb.__version__
Out[16]:
'0.21.0+137.gac9929d'