子图中的每列颜色条

per-column colorbars in subplots

我正在使用 plt.subplots() 创建一个包含 6 个地块的子图。 我想要每列两个不同的颜色图。第 1 行显示一些近似解,第 2 行显示与真实解的绝对逐点差异。我想为什么不为颜色图添加另一列并从该轴上删除除颜色图之外的所有内容。只能删除 x 和 y 刻度(参见 ax[0,3])。

MATPLOTLIB 有文档 here,但它太具体了。有谁知道做同样事情的捷径。

这是我的代码:

fig, ax = plt.subplots(2, 4,dpi=300)
#ax = fig.add_subplot(111)#, projection='3d')
n_points = 1000
# Intepolation points
x = np.linspace(lb[0], ub[0], n_points)
y = np.linspace(lb[1], ub[1], n_points)
# Create meshgrid
X, Y = np.meshgrid(x,y)
# PINN predictions
new_data = griddata(nodes, u_pred_5k.flatten(), (X, Y), method='linear')
sc1 =  ax[0,0].imshow(new_data, interpolation='nearest', cmap='jet', 
                  extent=[nodes[:,0].min(), nodes[:,0].max(), nodes[:,1].min(), nodes[:,1].max()], 
                  origin='lower', aspect='equal',vmin=0, vmax=1)
new_data = griddata(nodes, u_pred_10k.flatten(), (X, Y), method='linear')
sc1 =  ax[0,1].imshow(new_data, interpolation='nearest', cmap='jet', 
                  extent=[nodes[:,0].min(), nodes[:,0].max(), nodes[:,1].min(), nodes[:,1].max()], 
                  origin='lower', aspect='equal',vmin=0, vmax=1)
new_data = griddata(nodes, u_pred_30k.flatten(), (X, Y), method='linear')
sc1 =  ax[0,2].imshow(new_data, interpolation='nearest', cmap='jet', 
                  extent=[nodes[:,0].min(), nodes[:,0].max(), nodes[:,1].min(), nodes[:,1].max()], 
                  origin='lower', aspect='equal',vmin=0, vmax=1)
# Absolute error
new_data = griddata(nodes, abs(u_pred_5k-temperature).flatten(), (X, Y), method='linear')
sc2 =  ax[1,0].imshow(new_data, interpolation='nearest', cmap='jet', 
                  extent=[nodes[:,0].min(), nodes[:,0].max(), nodes[:,1].min(), nodes[:,1].max()], 
                  origin='lower', aspect='equal')
new_data = griddata(nodes, abs(u_pred_10k-temperature).flatten(), (X, Y), method='linear')
sc2 =  ax[1,1].imshow(new_data, interpolation='nearest', cmap='jet', 
                  extent=[nodes[:,0].min(), nodes[:,0].max(), nodes[:,1].min(), nodes[:,1].max()], 
                  origin='lower', aspect='equal')
new_data = griddata(nodes, abs(u_pred_30k-temperature).flatten(), (X, Y), method='linear')
sc2 =  ax[1,2].imshow(new_data, interpolation='nearest', cmap='jet', 
                  extent=[nodes[:,0].min(), nodes[:,0].max(), nodes[:,1].min(), nodes[:,1].max()], 
                  origin='lower', aspect='equal')
#sc = ax.scatter(nodes[:,0], nodes[:,1], c = best_u_pred, cmap=cm.jet, vmin=0, vmax=1)
#plt.colorbar(sc)
sc1.cmap.set_under('k')
sc1.cmap.set_over('k')
ax[0,0].set(xlabel='x', ylabel='y')
ax[0,1].set(xlabel='x', ylabel='y')
ax[0,2].set(xlabel='x', ylabel='y')
ax[1,0].set(xlabel='x', ylabel='y')
ax[1,1].set(xlabel='x', ylabel='y')
ax[1,2].set(xlabel='x', ylabel='y')
ax[0,0].set_title('5k iterations')
ax[0,1].set_title('10k iterations')
ax[0,2].set_title('30k iterations')
fig.colorbar(sc1, ax=ax[0,3])
ax[0,3].get_xaxis().set_visible(False)
ax[0,3].get_yaxis().set_visible(False)
#plt.xlabel('x')
#plt.ylabel('y')
#ax.set_zlabel('z')
#plt.title('Temperature distribution')
#fig.supxlabel('x')
#fig.supylabel('y')
fig.tight_layout()

这里我不得不拿出一些数据,所以我稍微改变了你的代码,但本质仍然存在。我会做的不同是用 ax[0, 3].axis(False)ax[1, 3].axis(False).

关闭轴
fig, ax = plt.subplots(2, 4)
# Intepolation points
x = np.linspace(-2, 2, 50)
y = np.linspace(-2, 2, 50)
# Create meshgrid
X, Y = np.meshgrid(x,y)
Z1 = np.cos(X**2 + Y**2)
Z2 = np.sin(X**2 + Y**2)

sc = []
c = 0
for i in range(2):
    for j in range(3):
        sc.append(ax[i, j].contourf(X, Y, Z1 if c < 3 else Z2, cmap=cm.viridis if c < 3 else cm.magma))
        ax[i, j].set_aspect("equal")
        c += 1

ax[0, 3].axis(False)
ax[1, 3].axis(False)
plt.colorbar(sc[0], ax=ax[0, 3])
plt.colorbar(sc[3], ax=ax[1, 3])

ax[0,0].set(xlabel='x', ylabel='y')
ax[0,1].set(xlabel='x', ylabel='y')
ax[0,2].set(xlabel='x', ylabel='y')
ax[1,0].set(xlabel='x', ylabel='y')
ax[1,1].set(xlabel='x', ylabel='y')
ax[1,2].set(xlabel='x', ylabel='y')
ax[0,0].set_title('5k iterations')
ax[0,1].set_title('10k iterations')
ax[0,2].set_title('30k iterations')

fig.tight_layout()