将热图与 subplot2grid 结合使用时确保列宽一致

Ensure consistent column widths when using heatmap with subplot2grid

我正在尝试格式化我的子图,但是,出于某种原因,我无法弄清楚为什么所有子图的位置都保持不变。现在,它们看起来像这样:

如您所见,我有两个问题:1.我不知道如何排除文本标签(如“日期”)和 2.我需要格式化子图以共享同一轴,所以他们保持一致。到目前为止我的代码:

fig = plt.figure(figsize=(25, 15))
ax1 = plt.subplot2grid((23,20), (0,0), colspan=19, rowspan=17)
ax2 = plt.subplot2grid((23,20), (19,0), colspan=19, rowspan=1)

sns.set(font_scale=0.95)

sns.heatmap(pivot, ax= ax1, annot=True, fmt=".0f", robust=True, linewidth=0.1, square=True, cmap="Blues")
sns.heatmap((pd.DataFrame(pivot.sum(axis=0))).transpose(), ax=ax2, annot=True, fmt=".0f", robust=True, linewidth=0.1, square=True, cmap="Blues", xticklabels=False, yticklabels=False)

plt.show()

我的数据框是这样的:

dates   2020Q1  2020Q2  2020Q3  2020Q4  2021Q1  2021Q2  2021Q3
inicio                                                        
2020Q1    56.0    45.0    15.0     7.0     4.0     4.0     3.0
2020Q2     NaN   418.0   277.0    86.0    46.0    33.0    28.0
2020Q3     NaN     NaN   619.0   398.0   167.0   122.0    93.0
2020Q4     NaN     NaN     NaN  1163.0   916.0   521.0   319.0
2021Q1     NaN     NaN     NaN     NaN   976.0   680.0   363.0
2021Q2     NaN     NaN     NaN     NaN     NaN   811.0   559.0
2021Q3     NaN     NaN     NaN     NaN     NaN     NaN  1879.0
  • seaborn.heatmap 中将 square=True 更改为 square=False 将使所有列具有相同的宽度。
  • 可以通过将标签设置为空字符串来删除标签:ax1.set(xlabel='', ylabel='')
  • 测试于 python 3.8.11pandas 1.3.3matplotlib 3.4.3seaborn 0.11.2
import panda as pd

# test dataframe
data = {'dates': ['2020Q1', '2020Q1', '2020Q1', '2020Q1', '2020Q1', '2020Q1', '2020Q1', '2020Q2', '2020Q2', '2020Q2', '2020Q2', '2020Q2', '2020Q2', '2020Q3', '2020Q3', '2020Q3', '2020Q3', '2020Q3', '2020Q4', '2020Q4', '2020Q4', '2020Q4', '2021Q1', '2021Q1', '2021Q1', '2021Q2', '2021Q2', '2021Q3'],
        'inicio': ['2020Q1', '2020Q2', '2020Q3', '2020Q4', '2021Q1', '2021Q2', '2021Q3', '2020Q2', '2020Q3', '2020Q4', '2021Q1', '2021Q2', '2021Q3', '2020Q3', '2020Q4', '2021Q1', '2021Q2', '2021Q3', '2020Q4', '2021Q1', '2021Q2', '2021Q3', '2021Q1', '2021Q2', '2021Q3', '2021Q2', '2021Q3', '2021Q3'],
        'values': [56.0, 45.0, 15.0, 7.0, 4.0, 4.0, 3.0, 418.0, 277.0, 86.0, 46.0, 33.0, 28.0, 619.0, 398.0, 167.0, 122.0, 93.0, 1163.0, 916.0, 521.0, 319.0, 976.0, 680.0, 363.0, 811.0, 559.0, 1879.0]}
df = pd.DataFrame(data)

# pivot the dataframe
pv = df.pivot(index='dates', columns='inicio', values='values')

# create figure and subplots
fig = plt.figure(figsize=(20, 10))
ax1 = plt.subplot2grid((20, 10), (0, 0), colspan=19, rowspan=17)
ax2 = plt.subplot2grid((20, 10), (19, 0), colspan=19, rowspan=1)

sns.set(font_scale=0.95)

# create heatmap with square=False instead of True
sns.heatmap(pv, ax=ax1, annot=True, fmt=".0f", robust=True, linewidth=0.1, square=False, cmap="Blues")
sns.heatmap(pv.sum().to_frame().T, ax=ax2, annot=True, fmt=".0f", robust=True, linewidth=0.1, square=False, cmap="Blues", xticklabels=False, yticklabels=False)

ax1.set_yticklabels(pv.columns, rotation=0)  # rotate the yticklabels
ax1.set(xlabel='', ylabel='')  # remove x & y labels
ax2.set(xlabel='', ylabel='')  # remove x & y labels

plt.show()