我怎样才能做相当于将数组附加到 numba 中的列表的操作?
How can I do the equivalent of appending an array to a list in numba?
我认为我的代码无法正常工作,因为我有一个数组列表。是否有不同的方法可以将 final_list
数组列表构造为矩阵,以便它被 numba
接受?
import numpy as np
import matplotlib.pyplot as plt
import numba as nb
N_SPLITS = 1000
@nb.jit(nopython=True)
def logi(x0, r):
x = x0
for n in range(30000):
x = x * r * (1-x)
final_list = [x]
for n in range(N_SPLITS):
final_list.append(final_list[-1] * r * ( 1 - final_list[-1]))
return np.sort(final_list, axis=0)
r = np.arange(2.4, 4., .001)
for i in range(N_SPLITS):
plt.plot(r, logi(0.5, r)[i], c='k', lw=0.1)
plt.savefig('bifig.pdf')
File "logi.py", line 18, in <module>
plt.plot(r, logi(0.5, r)[i], c='k', lw=0.1)
File "/usr/local/lib/python2.7/site-packages/numba/dispatcher.py", line 330, in _compile_for_args
raise e
numba.errors.TypingError: Caused By:
Traceback (most recent call last):
File "/usr/local/lib/python2.7/site-packages/numba/compiler.py", line 235, in run
stage()
File "/usr/local/lib/python2.7/site-packages/numba/compiler.py", line 449, in stage_nopython_frontend
self.locals)
File "/usr/local/lib/python2.7/site-packages/numba/compiler.py", line 805, in type_inference_stage
infer.propagate()
File "/usr/local/lib/python2.7/site-packages/numba/typeinfer.py", line 767, in propagate
raise errors[0]
TypingError: Invalid usage of BoundFunction(list.append for list(float64)) with parameters (array(float64, 1d, C))
* parameterized
File "logi.py", line 13
[1] During: resolving callee type: BoundFunction(list.append for list(float64))
[2] During: typing of call at logi.py (13)
Failed at nopython (nopython frontend)
Invalid usage of BoundFunction(list.append for list(float64)) with parameters (array(float64, 1d, C))
* parameterized
File "logi.py", line 13
[1] During: resolving callee type: BoundFunction(list.append for list(float64))
[2] During: typing of call at logi.py (13)
您的代码有很多问题导致 numba jit-compiler 出现问题:
带有参数的 np.sort
无效,也没有在二维数组上使用它
(参见:numpy 支持
特征)
x
从浮点数变为数组。 Numba 要求整个函数的类型一致性
下面是一个在 nopython
模式下编译并产生相同结果的 numba 函数。基本上我预先分配存储阵列,因为大小是预先知道的,然后按列排序。不幸的是 numba
没有很好的排序实现,所以你没有得到很大的加速。您可能还可以进行其他性能调整更改。另请注意,在绘图部分的每个循环中调用 logi
然后提取单个值是没有意义的。只需计算一次数组,然后选择您需要的值。
import numpy as np
import matplotlib.pyplot as plt
import numba as nb
N_SPLITS = 1000
%matplotlib inline
def logi_orig(x0, r):
x = x0
for n in range(30000):
x = x * r * (1-x)
final_list = [x]
for n in range(N_SPLITS):
final_list.append(final_list[-1] * r * ( 1 - final_list[-1]))
return np.sort(final_list, axis=0)
@nb.jit(nopython=True)
def logi_nb(x0, r):
x = np.full_like(r, x0)
for n in range(30000):
x = x * r * (1-x)
final_list = np.empty((N_SPLITS + 1, r.shape[0]), dtype=np.float64)
final_list[0,:] = x
for n in range(1, N_SPLITS + 1):
final_list[n, :] = final_list[n - 1] * r * ( 1 - final_list[n - 1])
out = np.empty_like(final_list)
for n in range(r.shape[0]):
out[:,n] = np.sort(final_list[:,n])
return out
def logi(x0, r):
x = np.full_like(r, x0)
for n in range(30000):
x = x * r * (1-x)
final_list = np.empty((N_SPLITS + 1, r.shape[0]), dtype=np.float64)
final_list[0,:] = x
for n in range(1, N_SPLITS + 1):
final_list[n, :] = final_list[n - 1] * r * ( 1 - final_list[n - 1])
return np.sort(final_list, axis=0)
r = np.arange(2.4, 4., .001)
y_orig = logi_orig(0.5, r)
y = logi(0.5, r)
y_nb = logi_nb(0.5, r)
print np.allclose(y, y_orig)
print np.allclose(y_nb, y_orig)
for i in range(N_SPLITS):
plt.plot(r, y[i], c='k', lw=0.1)
OSX (2014 MBP) 使用 Numba v0.34.0 的时间:
%timeit logi_orig(0.5, r)
%timeit logi(0.5, r)
%timeit logi_nb(0.5, r)
10 loops, best of 3: 171 ms per loop
10 loops, best of 3: 168 ms per loop
10 loops, best of 3: 77 ms per loop
我认为我的代码无法正常工作,因为我有一个数组列表。是否有不同的方法可以将 final_list
数组列表构造为矩阵,以便它被 numba
接受?
import numpy as np
import matplotlib.pyplot as plt
import numba as nb
N_SPLITS = 1000
@nb.jit(nopython=True)
def logi(x0, r):
x = x0
for n in range(30000):
x = x * r * (1-x)
final_list = [x]
for n in range(N_SPLITS):
final_list.append(final_list[-1] * r * ( 1 - final_list[-1]))
return np.sort(final_list, axis=0)
r = np.arange(2.4, 4., .001)
for i in range(N_SPLITS):
plt.plot(r, logi(0.5, r)[i], c='k', lw=0.1)
plt.savefig('bifig.pdf')
File "logi.py", line 18, in <module>
plt.plot(r, logi(0.5, r)[i], c='k', lw=0.1)
File "/usr/local/lib/python2.7/site-packages/numba/dispatcher.py", line 330, in _compile_for_args
raise e
numba.errors.TypingError: Caused By:
Traceback (most recent call last):
File "/usr/local/lib/python2.7/site-packages/numba/compiler.py", line 235, in run
stage()
File "/usr/local/lib/python2.7/site-packages/numba/compiler.py", line 449, in stage_nopython_frontend
self.locals)
File "/usr/local/lib/python2.7/site-packages/numba/compiler.py", line 805, in type_inference_stage
infer.propagate()
File "/usr/local/lib/python2.7/site-packages/numba/typeinfer.py", line 767, in propagate
raise errors[0]
TypingError: Invalid usage of BoundFunction(list.append for list(float64)) with parameters (array(float64, 1d, C))
* parameterized
File "logi.py", line 13
[1] During: resolving callee type: BoundFunction(list.append for list(float64))
[2] During: typing of call at logi.py (13)
Failed at nopython (nopython frontend)
Invalid usage of BoundFunction(list.append for list(float64)) with parameters (array(float64, 1d, C))
* parameterized
File "logi.py", line 13
[1] During: resolving callee type: BoundFunction(list.append for list(float64))
[2] During: typing of call at logi.py (13)
您的代码有很多问题导致 numba jit-compiler 出现问题:
-
带有参数的
np.sort
无效,也没有在二维数组上使用它 (参见:numpy 支持 特征)x
从浮点数变为数组。 Numba 要求整个函数的类型一致性
下面是一个在 nopython
模式下编译并产生相同结果的 numba 函数。基本上我预先分配存储阵列,因为大小是预先知道的,然后按列排序。不幸的是 numba
没有很好的排序实现,所以你没有得到很大的加速。您可能还可以进行其他性能调整更改。另请注意,在绘图部分的每个循环中调用 logi
然后提取单个值是没有意义的。只需计算一次数组,然后选择您需要的值。
import numpy as np
import matplotlib.pyplot as plt
import numba as nb
N_SPLITS = 1000
%matplotlib inline
def logi_orig(x0, r):
x = x0
for n in range(30000):
x = x * r * (1-x)
final_list = [x]
for n in range(N_SPLITS):
final_list.append(final_list[-1] * r * ( 1 - final_list[-1]))
return np.sort(final_list, axis=0)
@nb.jit(nopython=True)
def logi_nb(x0, r):
x = np.full_like(r, x0)
for n in range(30000):
x = x * r * (1-x)
final_list = np.empty((N_SPLITS + 1, r.shape[0]), dtype=np.float64)
final_list[0,:] = x
for n in range(1, N_SPLITS + 1):
final_list[n, :] = final_list[n - 1] * r * ( 1 - final_list[n - 1])
out = np.empty_like(final_list)
for n in range(r.shape[0]):
out[:,n] = np.sort(final_list[:,n])
return out
def logi(x0, r):
x = np.full_like(r, x0)
for n in range(30000):
x = x * r * (1-x)
final_list = np.empty((N_SPLITS + 1, r.shape[0]), dtype=np.float64)
final_list[0,:] = x
for n in range(1, N_SPLITS + 1):
final_list[n, :] = final_list[n - 1] * r * ( 1 - final_list[n - 1])
return np.sort(final_list, axis=0)
r = np.arange(2.4, 4., .001)
y_orig = logi_orig(0.5, r)
y = logi(0.5, r)
y_nb = logi_nb(0.5, r)
print np.allclose(y, y_orig)
print np.allclose(y_nb, y_orig)
for i in range(N_SPLITS):
plt.plot(r, y[i], c='k', lw=0.1)
OSX (2014 MBP) 使用 Numba v0.34.0 的时间:
%timeit logi_orig(0.5, r)
%timeit logi(0.5, r)
%timeit logi_nb(0.5, r)
10 loops, best of 3: 171 ms per loop
10 loops, best of 3: 168 ms per loop
10 loops, best of 3: 77 ms per loop