如何使用适合的图像和剖面图创建 Matplotlib 图?

How to create Matplotlib figure with image and profile plots that fit together?

我想将 2d 数据绘制为图像,并在下方和侧面显示沿 x 和 y 轴的剖面图。这是一种非常常见的显示数据的方式,因此可能有一种更简单的方法来实现这一点。我想找到最简单、最可靠的方法来正确地做到这一点,并且不使用 matplotlib 之外的任何东西(尽管我有兴趣了解其他可能特别相关的包)。特别是,如果数据的形状(纵横比)发生变化,该方法应该可以在不改变任何东西的情况下工作。

我的主要问题是让副图正确缩放,以便它们的边界与主图匹配。

示例代码:

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
# generate grid and test data
x, y = np.linspace(-3,3,300), np.linspace(-1,1,100)
X, Y = np.meshgrid(x,y)
def f(x,y) :
    return np.exp(-(x**2/4+y**2)/.2)*np.cos((x**2+y**2)*10)**2
data = f(X,Y)

# 2d image plot with profiles
h, w = data.shape
gs = gridspec.GridSpec(2, 2,width_ratios=[w,w*.2], height_ratios=[h,h*.2])
ax = [plt.subplot(gs[0]),plt.subplot(gs[1]),plt.subplot(gs[2])]
bounds = [x.min(),x.max(),y.min(),y.max()]
ax[0].imshow(data, cmap='gray', extent = bounds, origin='lower')
ax[1].plot(data[:,w/2],Y[:,w/2],'.',data[:,w/2],Y[:,w/2])
ax[1].axis([data[:,w/2].max(), data[:,w/2].min(), Y.min(), Y.max()])
ax[2].plot(X[h/2,:],data[h/2,:],'.',X[h/2,:],data[h/2,:])
plt.show()

从下面的输出中可以看出,图像向右缩放的方式与边界不匹配。

部分解决方案:

1) 手动调整图形大小以找到正确的宽高比以使其正确显示(可以自动使用图像比率 + 填充 + 使用的宽度比率吗?)。当已经有这么多本应自动处理这些事情的包装选项时,这似乎很俗气。 编辑: 如果未更改填充,plt.gcf().set_figheight(f.get_figwidth()*h/w) 似乎有效。

2) 添加 ax[0].set_aspect('auto') ,然后使边界对齐,但图像不再具有正确的纵横比。

上面代码示例的输出:

我无法使用 subplotgridspec 生成您的布局,同时仍保留 (1) 轴的比率和 (2) 施加在轴上的限制.另一种解决方案是将轴手动放置在图形中,并相应地控制图形的大小(正如您在 OP 中已经提到的那样)。虽然这比使用 subplotgridspec 需要更多的工作,但这种方法仍然非常简单,并且可以非常强大和灵活地生成需要精细控制边距和轴位置的复杂布局.

下面的示例展示了如何通过根据给定轴的大小设置图形的大小来实现这一点。相反,也可以将坐标轴放入预定义尺寸的图形中。然后通过使用图形边距作为缓冲区来保持轴的纵横比。

import numpy as np
import matplotlib.pyplot as plt

plt.close('all')

#------------------------------------------------------------ generate data ----

# generate grid and test data
x, y = np.linspace(-3, 3, 300), np.linspace(-1, 1, 100)
X, Y = np.meshgrid(x,y)
def f(x,y) :
    return np.exp(-(x**2/4+y**2)/.2)*np.cos((x**2+y**2)*10)**2
data = f(X,Y)

# 2d image plot with profiles
h, w = data.shape
data_ratio = h / float(w)

#------------------------------------------------------------ create figure ----

#--- define axes lenght in inches ----

width_ax0 = 8.
width_ax1 = 2.
height_ax2 = 2.

height_ax0 = width_ax0 * data_ratio

#---- define margins size in inches ----

left_margin  = 0.65
right_margin = 0.2
bottom_margin = 0.5
top_margin = 0.25
inter_margin = 0.5

#--- calculate total figure size in inches ----

fwidth = left_margin + right_margin + inter_margin + width_ax0 + width_ax1
fheight = bottom_margin + top_margin + inter_margin + height_ax0 + height_ax2

fig = plt.figure(figsize=(fwidth, fheight))
fig.patch.set_facecolor('white')

#---------------------------------------------------------------- create axe----

ax0 = fig.add_axes([left_margin / fwidth,
                    (bottom_margin + inter_margin + height_ax2) / fheight,
                    width_ax0 / fwidth, height_ax0 / fheight])

ax1 = fig.add_axes([(left_margin + width_ax0 + inter_margin) / fwidth,
                    (bottom_margin + inter_margin + height_ax2) / fheight,
                     width_ax1 / fwidth, height_ax0 / fheight])

ax2 = fig.add_axes([left_margin / fwidth, bottom_margin / fheight,
                    width_ax0 / fwidth, height_ax2 / fheight])

#---------------------------------------------------------------- plot data ----

bounds = [x.min(),x.max(),y.min(),y.max()]
ax0.imshow(data, cmap='gray', extent = bounds, origin='lower')
ax1.plot(data[:,w/2],Y[:,w/2],'.',data[:,w/2],Y[:,w/2])
ax1.invert_xaxis()
ax2.plot(X[h/2,:], data[h/2,:], '.', X[h/2,:], data[h/2,:])

plt.show(block=False)
fig.savefig('subplot_layout.png')

这导致:

您可以使用 sharexsharey 来执行此操作,将 ax= 行替换为:

ax = [plt.subplot(gs[0]),]
ax.append(plt.subplot(gs[1], sharey=ax[0]))
ax.append(plt.subplot(gs[2], sharex=ax[0]))

有趣的是,sharexsharey 的解决方案对我不起作用。它们对齐轴范围但 不是轴长度 ! 为了让它们可靠地对齐,我添加了:

pos  = ax[0].get_position()
pos1 = ax[1].get_position()
pos2 = ax[2].get_position()
ax[1].set_position([pos1.x0,pos.y0,pos1.width,pos.height])
ax[2].set_position([pos.x0,pos2.y0,pos.width,pos2.height])

因此,结合 CT Zhu 先前的回答,这使得:

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
# generate grid and test data
x, y = np.linspace(-3,3,300), np.linspace(-1,1,100)
X, Y = np.meshgrid(x,y)
def f(x,y) :
    return np.exp(-(x**2/4+y**2)/.2)*np.cos((x**2+y**2)*10)**2

data = f(X,Y)

# 2d image plot with profiles
h, w = data.shape
gs = gridspec.GridSpec(2, 2,width_ratios=[w,w*.2], height_ratios=[h,h*.2])
ax = [plt.subplot(gs[0]),]
ax.append(plt.subplot(gs[1], sharey=ax[0]))
ax.append(plt.subplot(gs[2], sharex=ax[0]))
bounds = [x.min(),x.max(),y.min(),y.max()]
ax[0].imshow(data, cmap='gray', extent = bounds, origin='lower')
ax[1].plot(data[:,int(w/2)],Y[:,int(w/2)],'.',data[:,int(w/2)],Y[:,int(w/2)])
ax[1].axis([data[:,int(w/2)].max(), data[:,int(w/2)].min(), Y.min(), Y.max()])
ax[2].plot(X[int(h/2),:],data[int(h/2),:],'.',X[int(h/2),:],data[int(h/2),:])
pos  = ax[0].get_position()
pos1 = ax[1].get_position()
pos2 = ax[2].get_position()
ax[1].set_position([pos1.x0,pos.y0,pos1.width,pos.height])
ax[2].set_position([pos.x0,pos2.y0,pos.width,pos2.height])
plt.show()