防止截断带有 sharex 的不等维的 imshow 子图

prevent cutoff of imshow subplots of unequal dimension w/ sharex

我有几个矩阵想用 imshow 显示在同一图的子图中。它们都具有相同的列数但行数不同。我想:

  1. imshow
  2. 显示时查看每个矩阵的所有内容
  3. 保留imshow
  4. aspect=1效果
  5. 图中使用sharex

(这意味着子图的高度反映了矩阵中不同的行数)。我尝试使用 gridspec(通过 plt.subplotsgridspec_kw 参数)但是 sharexaspect=1 的组合导致部分矩阵被切断,除非我手动调整 window 的大小。示例:

import numpy as np
import matplotlib.pyplot as plt
# fake data
foo = np.arange(5 * 7).reshape(5, 7)
bar = np.arange(11 * 7).reshape(11, 7)
baz = np.arange(3 * 7).reshape(3, 7)

data = [foo, bar, baz]
nrows = [x.shape[0] for x in data]
row_labels = np.array([x for x in 'abcdefghijk'])
col_labels = [x for x in 'ABCDEFG']

# initialize figure
fig, axs = plt.subplots(3, 1, squeeze=False, sharex=True,
                        gridspec_kw=dict(height_ratios=nrows))

for ix, d in enumerate(data):
    ax = axs[ix % axs.shape[0], ix // axs.shape[0]]
    _ = ax.imshow(d)
    _ = ax.yaxis.set_ticks(range(d.shape[0]))
    _ = ax.xaxis.set_ticks(range(d.shape[1]))
    _ = ax.yaxis.set_ticklabels(row_labels[np.arange(d.shape[0])])
    _ = ax.xaxis.set_ticklabels(col_labels)

根据每个矩阵中的行数,我可以猜测应该使它们全部可见的大概图形尺寸,但它不起作用:

figsize = (foo.shape[1], sum(nrows))
fig, axs = plt.subplots(3, 1, squeeze=False, sharex=True,
                        gridspec_kw=dict(height_ratios=nrows),
                        figsize=figsize)

for ix, d in enumerate(data):
    ax = axs[ix % axs.shape[0], ix // axs.shape[0]]
    _ = ax.imshow(d)
    _ = ax.yaxis.set_ticks(range(d.shape[0]))
    _ = ax.xaxis.set_ticks(range(d.shape[1]))
    _ = ax.yaxis.set_ticklabels(row_labels[np.arange(d.shape[0])])
    _ = ax.xaxis.set_ticklabels(col_labels)

注意所有 3 个子图的顶部和底部行是如何被部分截断的(最容易在中间的部分看到)并且顶部和底部图形边距处有大量多余的空白:

使用tight_layout也没有解决;它使子图太大(注意轴脊柱和图像之间每个子图 top/bottom 处的间隙):

有没有办法让imshowsharex在这里和谐工作?

我刚刚发现 ImageGrid,它的效果很好。完整示例:

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
foo = np.arange(5 * 7).reshape(5, 7)
bar = np.arange(11 * 7).reshape(11, 7)
baz = np.arange(3 * 7).reshape(3, 7)
data = [foo, bar, baz]
nrows = [x.shape[0] for x in data]
row_labels = np.array([x for x in 'abcdefghijk'])
col_labels = [x for x in 'ABCDEFG']
fig = plt.figure()
axs = ImageGrid(fig, 111, nrows_ncols=(3, 1), axes_pad=0.1)
for ix, d in enumerate(data):
    ax = axs[ix]
    _ = ax.imshow(d)
    _ = ax.yaxis.set_ticks(range(d.shape[0]))
    _ = ax.xaxis.set_ticks(range(d.shape[1]))
    _ = ax.yaxis.set_ticklabels(row_labels[np.arange(d.shape[0])])
    _ = ax.xaxis.set_ticklabels(col_labels)