Matplotlib 加速将绘图保存到磁盘
Matplotlib speed up saving plots to disk
我想用大约 250 个单独的帧创建一个动画,每一帧是一个包含 4 x 11 个子面板的图形,数据以 2D 图像的形式绘制。数据表示速度功率谱随时间频率和纬度的变化。然而,每一帧的创建和保存大约需要 4 秒(包括数据计算的运行时间)。在非交互式绘图模式下,我使用 'agg' 作为后端,以避免在交互式绘图功能上花费时间。
这里的速度瓶颈不是待绘制数据的计算,而是将图保存到磁盘。以随机数据(见下面的代码)且只有 5 帧的示例运行为例:不保存图时用时约 5 秒,保存图时用时 17–19 秒。对于我实际使用的数据,需要绘制更多的绘图元素(artist,例如面板上的文本、额外的线图等),但脚本执行时间非常相似。对于总共约 250 帧,这意味着大约 900 秒,即计算数据并保存图需要约 15 分钟。但是,由于我可能需要多次生成类似的帧,或使用略有不同的数据,因此最好能减少此脚本的执行时间。
下面给出了一段(希望是)可重现的代码,它使用随机数据,但数据大小与我实际使用的数据相同。示例帧(代码生成的第一帧)也可以在下面看到。在代码中,函数 create_fig() 生成一个图形,其子面板中先填入虚拟数据;在遍历各帧的 for 循环中,只替换子面板中的数据。
有没有办法加快将绘图保存到 png 文件中的速度?非常感谢任何帮助!
# import packages
import numpy as np
import time
import matplotlib as mpl
import matplotlib.pyplot as plt
# Base directory under which the rendered frames are written.
path_plots_out = '/home/proxauf'

# --- grid definitions ---
# Array dimensions: time samples, latitudes, longitudes.
nt, nlat, nlon = 3328, 24, 48

# Latitude grid: nlat points from -90 up to (90 - dlat).
dlat = 7.5
lats = np.linspace(-90, 90 - dlat, nlat)

# Frequency axis: negated FFT sample frequencies scaled by 1e9
# (presumably dt is in seconds, which would make nu nHz -- TODO confirm).
dt = 98191.08
nu = -np.fft.fftfreq(nt, dt) * 10 ** 9
nnu = nu.size

# Shifted (monotonic) copy of the frequency axis and its bin spacing.
nu_fftshift = np.fft.fftshift(nu)
dnu_fftshift = nu_fftshift[1] - nu_fftshift[0]

# Indices of the shifted axis that fall inside the plotted frequency band.
nu_lims = [-500, 500]
ind_nu_xlims = np.flatnonzero(
    (nu_fftshift >= nu_lims[0]) & (nu_fftshift <= nu_lims[1])
)

# Image extent [x_first, x_last, y_bottom, y_top]: half a bin is added on
# each side so that the outer pixels are centred on their grid values.
_nu_cut = nu_fftshift[ind_nu_xlims]
_half_dnu = dnu_fftshift / 2
_half_dlat = dlat / 2.0
ext_box_nu_lat = [
    _nu_cut[0] - _half_dnu,
    _nu_cut[-1] + _half_dnu,
    lats[0] - _half_dlat,
    lats[-1] + _half_dlat,
]
nnu_cut = ind_nu_xlims.size
# Backend selection: interactive mode is switched off first, then the GUI
# backend is used only if Matplotlib still reports interactive mode;
# otherwise the non-interactive 'agg' backend is selected.
plt.ioff()
if plt.rcParams['interactive']:
    mpl.use('Qt5Agg')
else:
    mpl.use('agg')
# plotting function
def create_fig(minor_ticks=True):
    """Build the grid of image panels plus one shared colorbar.

    Every panel is initialised with an all-zero image so that the frame loop
    only has to replace the image data (``AxesImage.set_data``) before saving.

    Reads the module-level names ``nrows``, ``ncols``, ``figsize``,
    ``nnu_cut``, ``nlat`` and ``ext_box_nu_lat``.

    Parameters
    ----------
    minor_ticks : bool, optional
        Draw minor ticks on both axes of every panel. Defaults to ``True``
        (the original behaviour); pass ``False`` to skip them so that fewer
        ticks have to be drawn on every save.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The created figure.
    axes : numpy.ndarray
        2-D array of Axes with shape ``(nrows, ncols)``.
    data_list : list of list
        ``data_list[i][j]`` is the image artist of panel ``(i, j)``.
    """
    # Tick positions are identical for every panel.
    x_major = np.linspace(-300, 300, 3)
    x_minor = np.linspace(-500, 500, 21)
    y_major = np.linspace(-90, 90, 7)
    y_minor = np.linspace(-90, 90, 25)

    # squeeze=False keeps `axes` 2-D even for a single row or column.
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)

    data_list = []
    for i in range(nrows):
        is_bottom_row = i == nrows - 1
        row_images = []
        for j in range(ncols):
            is_left_col = j == 0
            ax = axes[i, j]
            im = ax.imshow(
                np.zeros((nlat, nnu_cut)),
                interpolation='nearest',
                origin='lower',
                aspect='auto',
                cmap='binary',
                extent=ext_box_nu_lat,
            )
            im.set_clim(0, 1e4)

            # Axis labels only on the bottom row / left column.
            ax.set_xlabel('Frequency [nHz]' if is_bottom_row else '')
            ax.set_ylabel('Latitude [deg]' if is_left_col else '')
            ax.set_xlim(-500, 500)
            ax.set_ylim(-90, 90)
            ax.set_xticks(x_major)
            ax.set_yticks(y_major)
            if minor_ticks:
                ax.set_xticks(x_minor, minor=True)
                ax.set_yticks(y_minor, minor=True)

            # Tick labels only on the bottom row / left column.
            if not is_bottom_row:
                ax.tick_params(labelbottom=False)
            if not is_left_col:
                ax.tick_params(labelleft=False)
            row_images.append(im)
        data_list.append(row_images)

    fig.subplots_adjust(left=0.04, right=0.95, bottom=0.06, top=0.90,
                        hspace=0.1, wspace=0.1)
    # Render once before querying the panel positions (kept from the
    # original; NOTE(review): possibly unnecessary with aspect='auto').
    fig.canvas.draw()

    # Colorbar axes to the right of the last column, spanning the full
    # height of the panel grid.
    top = axes[0, -1].get_position().y1
    lower_right = axes[-1, -1].get_position()
    bottom = lower_right.y0
    right = lower_right.x1
    cbar_pad = 0.01
    cbar_width = 0.01
    cax = fig.add_axes([right + cbar_pad, bottom, cbar_width, top - bottom])
    fig.colorbar(data_list[-1][-1], cax=cax)
    return fig, axes, data_list
# Panel grid layout and figure size (read by create_fig()).
nrows = 4
ncols = 11
figsize = (16.5, 8)

# create figure with empty subpanels
fig, axes, data_list = create_fig()

# generate some data: four random (nt, nlat, nlon) arrays
np.random.seed(100)
data1 = np.random.rand(nt, nlat, nlon)
data2 = np.random.rand(nt, nlat, nlon)
data3 = np.random.rand(nt, nlat, nlon)
data4 = np.random.rand(nt, nlat, nlon)

# Zero-padded working copies: only the rows inside the current time window
# are non-zero. The first frame uses the window [0, wsize).
wsize = nt // 4
data1_temp = np.zeros((nt, nlat, nlon))
data2_temp = np.zeros((nt, nlat, nlon))
data3_temp = np.zeros((nt, nlat, nlon))
data4_temp = np.zeros((nt, nlat, nlon))
_data_pairs = [
    (data1, data1_temp),
    (data2, data2_temp),
    (data3, data3_temp),
    (data4, data4_temp),
]
for _data, _data_temp in _data_pairs:
    _data_temp[:wsize] = _data[:wsize]

frame_cad = 10
# do not activate, else program will take about 15-20 minutes to finish
# frame_inds = range(0, nt - wsize + 1, frame_cad)
frame_inds = range(0, 50, frame_cad)

t0 = time.time()
for c, i in enumerate(frame_inds):
    print(c)
    if i >= 1:
        # Slide the window from [i - frame_cad, i - frame_cad + wsize) to
        # [i, i + wsize): clear the rows that dropped out at the start and
        # copy in the frame_cad rows that entered at the end. (The refill
        # previously started at i + wsize - 1, which is only correct for
        # frame_cad == 1 and otherwise leaves a gap of zero rows.)
        for _data, _data_temp in _data_pairs:
            _data_temp[i - frame_cad:i] = 0.0
            _data_temp[i + wsize - frame_cad:i + wsize] = \
                _data[i + wsize - frame_cad:i + wsize]

    # compute power spectrum (FFT over the time and longitude axes)
    pu_temp_list = [
        np.abs(np.fft.fftn(_data_temp, axes=(0, 2))) ** 2
        for _, _data_temp in _data_pairs
    ]

    # update data in subpanels: row s shows spectrum s, column j shows
    # index j along the last (transformed longitude) axis
    for row_images, pu_temp in zip(data_list, pu_temp_list):
        # shift and cut the frequency axis once per row instead of once
        # per panel; only the first ncols columns are plotted
        panels = np.fft.fftshift(pu_temp[:, :, :ncols], axes=(0,))[ind_nu_xlims]
        for j, im in enumerate(row_images):
            im.set_data(panels[:, :, j].T)

    # save figure
    fig.savefig('%s/Whosebug_test/frame_%04d.png' % (path_plots_out, c))

# the same figure is reused for every frame, so close it once at the end
plt.close(fig)
print(time.time() - t0)
更新:下面给出修改后的代码块(去掉次刻度,用 pyfftw 代替 numpy,使用更快的模平方计算;注意:create_fig() 返回的 data_list 已重命名为 plot_data_list)。修改后 5 帧的运行时间约为 6 秒。最大的速度提升来自停用次刻度(如下面的回答中所述)。
# Select the in-band frequencies directly on the unshifted axis and keep the
# permutation that sorts them in ascending order; applying this permutation
# later replaces the np.fft.fftshift() call (slight speed boost).
ind_nu_xlims = np.flatnonzero((nu >= nu_lims[0]) & (nu <= nu_lims[1]))
ind_nu_sort = nu[ind_nu_xlims].argsort()
nu_sort = nu[ind_nu_xlims][ind_nu_sort]
# Image extent [x_first, x_last, y_bottom, y_top] for the sorted axis.
ext_box_nu_lat = [
    nu_sort[0] + dnu_fftshift / 2,
    nu_sort[-1] - dnu_fftshift / 2,
    lats[0] - dlat / 2.0,
    lats[-1] + dlat / 2.0,
]
# plotting function
def create_fig():
    # NOTE: excerpt only -- the rest of the function body is not shown here.
    # Deactivating the minor ticks massively (!) boosts plotting performance,
    # so the two minor-tick calls are commented out:
    # ax.set_xticks(data_xticks_minor[i, j], minor=True)
    # ax.set_yticks(data_yticks_minor[i, j], minor=True)
# Full input arrays and their zero-padded, windowed working copies
# (loop-invariant, so both lists are built once).
data_list = [data1, data2, data3, data4]
data_temp_list = [data1_temp, data2_temp, data3_temp, data4_temp]

# wisdom makes FFTs much faster using pyfftw than using numpy
# enable cache and set cache memory-keeping time sufficiently large
# this depends on the computation time between FFT calls
pyfftw.interfaces.cache.enable()
pyfftw.interfaces.cache.set_keepalive_time(5)

for c, i in enumerate(frame_inds):
    print(c)
    pu_temp_list = []
    for data, data_temp in zip(data_list, data_temp_list):
        if i >= 1:
            # Slide the window from [i - frame_cad, i - frame_cad + wsize)
            # to [i, i + wsize): clear the rows that dropped out and copy
            # in the frame_cad rows that entered. (The refill previously
            # started at i + wsize - 1, which is only correct for
            # frame_cad == 1 and otherwise leaves a gap of zero rows.)
            data_temp[i - frame_cad:i] = 0.0
            data_temp[i + wsize - frame_cad:i + wsize] = \
                data[i + wsize - frame_cad:i + wsize]
        # compute Fourier transform via pyfftw (time and longitude axes)
        pu_temp = pyfftw.interfaces.numpy_fft.fftn(data_temp, axes=(0, 2), threads=-1)
        # absolute square as real(x * conj(x)); author's timing notes:
        # about the same speed as np.real(x)**2 + np.imag(x)**2,
        # faster than np.einsum('ijk,ijk->ijk', x, np.conj(x)), and faster
        # than np.abs(x)**2, which takes a square root and squares it again
        pu_temp_list.append(np.real(pu_temp * np.conj(pu_temp)))

    # update data in subpanels
    for s in range(nrows):
        for j in range(ncols):
            # use np.take_along_axis() with sorting indices instead of
            # np.fft.fftshift(), gives a slight (not too much!) speed boost
            plot_data_list[s][j].set_data(
                np.take_along_axis(
                    pu_temp_list[s][ind_nu_xlims, :, j],
                    ind_nu_sort[:, None],
                    axis=0,
                ).T
            )

    # save figure
    fig.savefig('%s/Whosebug_test/frame_%04d.png' % (path_plots_out, c))

# the same figure is reused for every frame, so close it once at the end
plt.close(fig)
print(time.time() - t0)
我给你一些提示,但这不是完整的解决方案:
你直接对矩阵进行运算是正确的做法,但可以检查一下能否通过转置矩阵来最大化缓存利用率(当矩阵又高又窄时)。
您听说过稀疏矩阵或矩阵压缩技术吗?
把 i<1 时需要执行的操作移到 for 循环之外——去掉这个判断后,每次迭代可以省去 1 次比较。
可以使用并行计算吗?比如 Python 中类似 OpenMP 的方案?
因此,如果这正是您想要的图,那么我认为您的做法已经差不多是最快的了。我这里 5 张图带保存用时 15 秒,不保存用时 5 秒。
信不信由你,让它更快的简单方法是去掉次刻度(minor ticks)。如果我把这些行注释掉,用时为 8 秒,加速约 70%。在 matplotlib 中,刻度的开销非常大。鉴于您的次刻度本来就很小,我建议把这作为一种简单的优化。
我想用大约 250 个单独的帧创建一个动画,每一帧是一个包含 4 x 11 个子面板的图形,数据以 2D 图像的形式绘制。数据表示速度功率谱随时间频率和纬度的变化。然而,每一帧的创建和保存大约需要 4 秒(包括数据计算的运行时间)。在非交互式绘图模式下,我使用 'agg' 作为后端,以避免在交互式绘图功能上花费时间。
这里的速度瓶颈不是待绘制数据的计算,而是将图保存到磁盘。以随机数据(见下面的代码)且只有 5 帧的示例运行为例:不保存图时用时约 5 秒,保存图时用时 17–19 秒。对于我实际使用的数据,需要绘制更多的绘图元素(artist,例如面板上的文本、额外的线图等),但脚本执行时间非常相似。对于总共约 250 帧,这意味着大约 900 秒,即计算数据并保存图需要约 15 分钟。但是,由于我可能需要多次生成类似的帧,或使用略有不同的数据,因此最好能减少此脚本的执行时间。
下面给出了一段(希望是)可重现的代码,它使用随机数据,但数据大小与我实际使用的数据相同。示例帧(代码生成的第一帧)也可以在下面看到。在代码中,函数 create_fig() 生成一个图形,其子面板中先填入虚拟数据;在遍历各帧的 for 循环中,只替换子面板中的数据。
有没有办法加快将绘图保存到 png 文件中的速度?非常感谢任何帮助!
# import packages
import numpy as np
import time
import matplotlib as mpl
import matplotlib.pyplot as plt
# Base directory under which the rendered frames are written.
path_plots_out = '/home/proxauf'

# --- grid definitions ---
# Array dimensions: time samples, latitudes, longitudes.
nt, nlat, nlon = 3328, 24, 48

# Latitude grid: nlat points from -90 up to (90 - dlat).
dlat = 7.5
lats = np.linspace(-90, 90 - dlat, nlat)

# Frequency axis: negated FFT sample frequencies scaled by 1e9
# (presumably dt is in seconds, which would make nu nHz -- TODO confirm).
dt = 98191.08
nu = -np.fft.fftfreq(nt, dt) * 10 ** 9
nnu = nu.size

# Shifted (monotonic) copy of the frequency axis and its bin spacing.
nu_fftshift = np.fft.fftshift(nu)
dnu_fftshift = nu_fftshift[1] - nu_fftshift[0]

# Indices of the shifted axis that fall inside the plotted frequency band.
nu_lims = [-500, 500]
ind_nu_xlims = np.flatnonzero(
    (nu_fftshift >= nu_lims[0]) & (nu_fftshift <= nu_lims[1])
)

# Image extent [x_first, x_last, y_bottom, y_top]: half a bin is added on
# each side so that the outer pixels are centred on their grid values.
_nu_cut = nu_fftshift[ind_nu_xlims]
_half_dnu = dnu_fftshift / 2
_half_dlat = dlat / 2.0
ext_box_nu_lat = [
    _nu_cut[0] - _half_dnu,
    _nu_cut[-1] + _half_dnu,
    lats[0] - _half_dlat,
    lats[-1] + _half_dlat,
]
nnu_cut = ind_nu_xlims.size
# Backend selection: interactive mode is switched off first, then the GUI
# backend is used only if Matplotlib still reports interactive mode;
# otherwise the non-interactive 'agg' backend is selected.
plt.ioff()
if plt.rcParams['interactive']:
    mpl.use('Qt5Agg')
else:
    mpl.use('agg')
# plotting function
def create_fig(minor_ticks=True):
    """Build the grid of image panels plus one shared colorbar.

    Every panel is initialised with an all-zero image so that the frame loop
    only has to replace the image data (``AxesImage.set_data``) before saving.

    Reads the module-level names ``nrows``, ``ncols``, ``figsize``,
    ``nnu_cut``, ``nlat`` and ``ext_box_nu_lat``.

    Parameters
    ----------
    minor_ticks : bool, optional
        Draw minor ticks on both axes of every panel. Defaults to ``True``
        (the original behaviour); pass ``False`` to skip them so that fewer
        ticks have to be drawn on every save.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The created figure.
    axes : numpy.ndarray
        2-D array of Axes with shape ``(nrows, ncols)``.
    data_list : list of list
        ``data_list[i][j]`` is the image artist of panel ``(i, j)``.
    """
    # Tick positions are identical for every panel.
    x_major = np.linspace(-300, 300, 3)
    x_minor = np.linspace(-500, 500, 21)
    y_major = np.linspace(-90, 90, 7)
    y_minor = np.linspace(-90, 90, 25)

    # squeeze=False keeps `axes` 2-D even for a single row or column.
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)

    data_list = []
    for i in range(nrows):
        is_bottom_row = i == nrows - 1
        row_images = []
        for j in range(ncols):
            is_left_col = j == 0
            ax = axes[i, j]
            im = ax.imshow(
                np.zeros((nlat, nnu_cut)),
                interpolation='nearest',
                origin='lower',
                aspect='auto',
                cmap='binary',
                extent=ext_box_nu_lat,
            )
            im.set_clim(0, 1e4)

            # Axis labels only on the bottom row / left column.
            ax.set_xlabel('Frequency [nHz]' if is_bottom_row else '')
            ax.set_ylabel('Latitude [deg]' if is_left_col else '')
            ax.set_xlim(-500, 500)
            ax.set_ylim(-90, 90)
            ax.set_xticks(x_major)
            ax.set_yticks(y_major)
            if minor_ticks:
                ax.set_xticks(x_minor, minor=True)
                ax.set_yticks(y_minor, minor=True)

            # Tick labels only on the bottom row / left column.
            if not is_bottom_row:
                ax.tick_params(labelbottom=False)
            if not is_left_col:
                ax.tick_params(labelleft=False)
            row_images.append(im)
        data_list.append(row_images)

    fig.subplots_adjust(left=0.04, right=0.95, bottom=0.06, top=0.90,
                        hspace=0.1, wspace=0.1)
    # Render once before querying the panel positions (kept from the
    # original; NOTE(review): possibly unnecessary with aspect='auto').
    fig.canvas.draw()

    # Colorbar axes to the right of the last column, spanning the full
    # height of the panel grid.
    top = axes[0, -1].get_position().y1
    lower_right = axes[-1, -1].get_position()
    bottom = lower_right.y0
    right = lower_right.x1
    cbar_pad = 0.01
    cbar_width = 0.01
    cax = fig.add_axes([right + cbar_pad, bottom, cbar_width, top - bottom])
    fig.colorbar(data_list[-1][-1], cax=cax)
    return fig, axes, data_list
# Panel grid layout and figure size (read by create_fig()).
nrows = 4
ncols = 11
figsize = (16.5, 8)

# create figure with empty subpanels
fig, axes, data_list = create_fig()

# generate some data: four random (nt, nlat, nlon) arrays
np.random.seed(100)
data1 = np.random.rand(nt, nlat, nlon)
data2 = np.random.rand(nt, nlat, nlon)
data3 = np.random.rand(nt, nlat, nlon)
data4 = np.random.rand(nt, nlat, nlon)

# Zero-padded working copies: only the rows inside the current time window
# are non-zero. The first frame uses the window [0, wsize).
wsize = nt // 4
data1_temp = np.zeros((nt, nlat, nlon))
data2_temp = np.zeros((nt, nlat, nlon))
data3_temp = np.zeros((nt, nlat, nlon))
data4_temp = np.zeros((nt, nlat, nlon))
_data_pairs = [
    (data1, data1_temp),
    (data2, data2_temp),
    (data3, data3_temp),
    (data4, data4_temp),
]
for _data, _data_temp in _data_pairs:
    _data_temp[:wsize] = _data[:wsize]

frame_cad = 10
# do not activate, else program will take about 15-20 minutes to finish
# frame_inds = range(0, nt - wsize + 1, frame_cad)
frame_inds = range(0, 50, frame_cad)

t0 = time.time()
for c, i in enumerate(frame_inds):
    print(c)
    if i >= 1:
        # Slide the window from [i - frame_cad, i - frame_cad + wsize) to
        # [i, i + wsize): clear the rows that dropped out at the start and
        # copy in the frame_cad rows that entered at the end. (The refill
        # previously started at i + wsize - 1, which is only correct for
        # frame_cad == 1 and otherwise leaves a gap of zero rows.)
        for _data, _data_temp in _data_pairs:
            _data_temp[i - frame_cad:i] = 0.0
            _data_temp[i + wsize - frame_cad:i + wsize] = \
                _data[i + wsize - frame_cad:i + wsize]

    # compute power spectrum (FFT over the time and longitude axes)
    pu_temp_list = [
        np.abs(np.fft.fftn(_data_temp, axes=(0, 2))) ** 2
        for _, _data_temp in _data_pairs
    ]

    # update data in subpanels: row s shows spectrum s, column j shows
    # index j along the last (transformed longitude) axis
    for row_images, pu_temp in zip(data_list, pu_temp_list):
        # shift and cut the frequency axis once per row instead of once
        # per panel; only the first ncols columns are plotted
        panels = np.fft.fftshift(pu_temp[:, :, :ncols], axes=(0,))[ind_nu_xlims]
        for j, im in enumerate(row_images):
            im.set_data(panels[:, :, j].T)

    # save figure
    fig.savefig('%s/Whosebug_test/frame_%04d.png' % (path_plots_out, c))

# the same figure is reused for every frame, so close it once at the end
plt.close(fig)
print(time.time() - t0)
更新:下面给出修改后的代码块(去掉次刻度,用 pyfftw 代替 numpy,使用更快的模平方计算;注意:create_fig() 返回的 data_list 已重命名为 plot_data_list)。修改后 5 帧的运行时间约为 6 秒。最大的速度提升来自停用次刻度(如下面的回答中所述)。
# Select the in-band frequencies directly on the unshifted axis and keep the
# permutation that sorts them in ascending order; applying this permutation
# later replaces the np.fft.fftshift() call (slight speed boost).
ind_nu_xlims = np.flatnonzero((nu >= nu_lims[0]) & (nu <= nu_lims[1]))
ind_nu_sort = nu[ind_nu_xlims].argsort()
nu_sort = nu[ind_nu_xlims][ind_nu_sort]
# Image extent [x_first, x_last, y_bottom, y_top] for the sorted axis.
ext_box_nu_lat = [
    nu_sort[0] + dnu_fftshift / 2,
    nu_sort[-1] - dnu_fftshift / 2,
    lats[0] - dlat / 2.0,
    lats[-1] + dlat / 2.0,
]
# plotting function
def create_fig():
    # NOTE: excerpt only -- the rest of the function body is not shown here.
    # Deactivating the minor ticks massively (!) boosts plotting performance,
    # so the two minor-tick calls are commented out:
    # ax.set_xticks(data_xticks_minor[i, j], minor=True)
    # ax.set_yticks(data_yticks_minor[i, j], minor=True)
# Full input arrays and their zero-padded, windowed working copies
# (loop-invariant, so both lists are built once).
data_list = [data1, data2, data3, data4]
data_temp_list = [data1_temp, data2_temp, data3_temp, data4_temp]

# wisdom makes FFTs much faster using pyfftw than using numpy
# enable cache and set cache memory-keeping time sufficiently large
# this depends on the computation time between FFT calls
pyfftw.interfaces.cache.enable()
pyfftw.interfaces.cache.set_keepalive_time(5)

for c, i in enumerate(frame_inds):
    print(c)
    pu_temp_list = []
    for data, data_temp in zip(data_list, data_temp_list):
        if i >= 1:
            # Slide the window from [i - frame_cad, i - frame_cad + wsize)
            # to [i, i + wsize): clear the rows that dropped out and copy
            # in the frame_cad rows that entered. (The refill previously
            # started at i + wsize - 1, which is only correct for
            # frame_cad == 1 and otherwise leaves a gap of zero rows.)
            data_temp[i - frame_cad:i] = 0.0
            data_temp[i + wsize - frame_cad:i + wsize] = \
                data[i + wsize - frame_cad:i + wsize]
        # compute Fourier transform via pyfftw (time and longitude axes)
        pu_temp = pyfftw.interfaces.numpy_fft.fftn(data_temp, axes=(0, 2), threads=-1)
        # absolute square as real(x * conj(x)); author's timing notes:
        # about the same speed as np.real(x)**2 + np.imag(x)**2,
        # faster than np.einsum('ijk,ijk->ijk', x, np.conj(x)), and faster
        # than np.abs(x)**2, which takes a square root and squares it again
        pu_temp_list.append(np.real(pu_temp * np.conj(pu_temp)))

    # update data in subpanels
    for s in range(nrows):
        for j in range(ncols):
            # use np.take_along_axis() with sorting indices instead of
            # np.fft.fftshift(), gives a slight (not too much!) speed boost
            plot_data_list[s][j].set_data(
                np.take_along_axis(
                    pu_temp_list[s][ind_nu_xlims, :, j],
                    ind_nu_sort[:, None],
                    axis=0,
                ).T
            )

    # save figure
    fig.savefig('%s/Whosebug_test/frame_%04d.png' % (path_plots_out, c))

# the same figure is reused for every frame, so close it once at the end
plt.close(fig)
print(time.time() - t0)
我给你一些提示,但这不是完整的解决方案:
你直接对矩阵进行运算是正确的做法,但可以检查一下能否通过转置矩阵来最大化缓存利用率(当矩阵又高又窄时)。
您听说过稀疏矩阵或矩阵压缩技术吗?
把 i<1 时需要执行的操作移到 for 循环之外——去掉这个判断后,每次迭代可以省去 1 次比较。
可以使用并行计算吗?比如 Python 中类似 OpenMP 的方案?
因此,如果这正是您想要的图,那么我认为您的做法已经差不多是最快的了。我这里 5 张图带保存用时 15 秒,不保存用时 5 秒。
信不信由你,让它更快的简单方法是去掉次刻度(minor ticks)。如果我把这些行注释掉,用时为 8 秒,加速约 70%。在 matplotlib 中,刻度的开销非常大。鉴于您的次刻度本来就很小,我建议把这作为一种简单的优化。