将热图与 subplot2grid 结合使用时确保列宽一致
Ensure consistent column widths when using heatmap with subplot2grid
我正在尝试格式化我的子图,但是,出于某种原因,我无法弄清楚为什么所有子图的位置都保持不变。现在,它们看起来像这样:
如您所见,我有两个问题:1.我不知道如何排除文本标签(如“日期”)和 2.我需要格式化子图以共享同一轴,所以他们保持一致。到目前为止我的代码:
fig = plt.figure(figsize=(25, 15))
ax1 = plt.subplot2grid((23,20), (0,0), colspan=19, rowspan=17)
ax2 = plt.subplot2grid((23,20), (19,0), colspan=19, rowspan=1)
sns.set(font_scale=0.95)
sns.heatmap(pivot, ax= ax1, annot=True, fmt=".0f", robust=True, linewidth=0.1, square=True, cmap="Blues")
sns.heatmap((pd.DataFrame(pivot.sum(axis=0))).transpose(), ax=ax2, annot=True, fmt=".0f", robust=True, linewidth=0.1, square=True, cmap="Blues", xticklabels=False, yticklabels=False)
plt.show()
我的数据框是这样的:
dates 2020Q1 2020Q2 2020Q3 2020Q4 2021Q1 2021Q2 2021Q3
inicio
2020Q1 56.0 45.0 15.0 7.0 4.0 4.0 3.0
2020Q2 NaN 418.0 277.0 86.0 46.0 33.0 28.0
2020Q3 NaN NaN 619.0 398.0 167.0 122.0 93.0
2020Q4 NaN NaN NaN 1163.0 916.0 521.0 319.0
2021Q1 NaN NaN NaN NaN 976.0 680.0 363.0
2021Q2 NaN NaN NaN NaN NaN 811.0 559.0
2021Q3 NaN NaN NaN NaN NaN NaN 1879.0
- 在
seaborn.heatmap
中将 square=True
更改为 square=False
将使所有列具有相同的宽度。
- 可以通过将标签设置为空字符串来删除标签:
ax1.set(xlabel='', ylabel='')
- 测试于
python 3.8.11
、pandas 1.3.3
、matplotlib 3.4.3
、seaborn 0.11.2
import panda as pd
# test dataframe
data = {'dates': ['2020Q1', '2020Q1', '2020Q1', '2020Q1', '2020Q1', '2020Q1', '2020Q1', '2020Q2', '2020Q2', '2020Q2', '2020Q2', '2020Q2', '2020Q2', '2020Q3', '2020Q3', '2020Q3', '2020Q3', '2020Q3', '2020Q4', '2020Q4', '2020Q4', '2020Q4', '2021Q1', '2021Q1', '2021Q1', '2021Q2', '2021Q2', '2021Q3'],
'inicio': ['2020Q1', '2020Q2', '2020Q3', '2020Q4', '2021Q1', '2021Q2', '2021Q3', '2020Q2', '2020Q3', '2020Q4', '2021Q1', '2021Q2', '2021Q3', '2020Q3', '2020Q4', '2021Q1', '2021Q2', '2021Q3', '2020Q4', '2021Q1', '2021Q2', '2021Q3', '2021Q1', '2021Q2', '2021Q3', '2021Q2', '2021Q3', '2021Q3'],
'values': [56.0, 45.0, 15.0, 7.0, 4.0, 4.0, 3.0, 418.0, 277.0, 86.0, 46.0, 33.0, 28.0, 619.0, 398.0, 167.0, 122.0, 93.0, 1163.0, 916.0, 521.0, 319.0, 976.0, 680.0, 363.0, 811.0, 559.0, 1879.0]}
df = pd.DataFrame(data)
# pivot the dataframe
pv = df.pivot(index='dates', columns='inicio', values='values')
# create figure and subplots
fig = plt.figure(figsize=(20, 10))
ax1 = plt.subplot2grid((20, 10), (0, 0), colspan=19, rowspan=17)
ax2 = plt.subplot2grid((20, 10), (19, 0), colspan=19, rowspan=1)
sns.set(font_scale=0.95)
# create heatmap with square=False instead of True
sns.heatmap(pv, ax=ax1, annot=True, fmt=".0f", robust=True, linewidth=0.1, square=False, cmap="Blues")
sns.heatmap(pv.sum().to_frame().T, ax=ax2, annot=True, fmt=".0f", robust=True, linewidth=0.1, square=False, cmap="Blues", xticklabels=False, yticklabels=False)
ax1.set_yticklabels(pv.columns, rotation=0) # rotate the yticklabels
ax1.set(xlabel='', ylabel='') # remove x & y labels
ax2.set(xlabel='', ylabel='') # remove x & y labels
plt.show()
我正在尝试格式化我的子图,但是,出于某种原因,我无法弄清楚为什么所有子图的位置都保持不变。现在,它们看起来像这样:
如您所见,我有两个问题:1.我不知道如何排除文本标签(如“日期”)和 2.我需要格式化子图以共享同一轴,所以他们保持一致。到目前为止我的代码:
fig = plt.figure(figsize=(25, 15))
ax1 = plt.subplot2grid((23,20), (0,0), colspan=19, rowspan=17)
ax2 = plt.subplot2grid((23,20), (19,0), colspan=19, rowspan=1)
sns.set(font_scale=0.95)
sns.heatmap(pivot, ax= ax1, annot=True, fmt=".0f", robust=True, linewidth=0.1, square=True, cmap="Blues")
sns.heatmap((pd.DataFrame(pivot.sum(axis=0))).transpose(), ax=ax2, annot=True, fmt=".0f", robust=True, linewidth=0.1, square=True, cmap="Blues", xticklabels=False, yticklabels=False)
plt.show()
我的数据框是这样的:
dates 2020Q1 2020Q2 2020Q3 2020Q4 2021Q1 2021Q2 2021Q3
inicio
2020Q1 56.0 45.0 15.0 7.0 4.0 4.0 3.0
2020Q2 NaN 418.0 277.0 86.0 46.0 33.0 28.0
2020Q3 NaN NaN 619.0 398.0 167.0 122.0 93.0
2020Q4 NaN NaN NaN 1163.0 916.0 521.0 319.0
2021Q1 NaN NaN NaN NaN 976.0 680.0 363.0
2021Q2 NaN NaN NaN NaN NaN 811.0 559.0
2021Q3 NaN NaN NaN NaN NaN NaN 1879.0
- 在
seaborn.heatmap
中将square=True
更改为square=False
将使所有列具有相同的宽度。 - 可以通过将标签设置为空字符串来删除标签:
ax1.set(xlabel='', ylabel='')
- 测试于
python 3.8.11
、pandas 1.3.3
、matplotlib 3.4.3
、seaborn 0.11.2
import panda as pd
# test dataframe
data = {'dates': ['2020Q1', '2020Q1', '2020Q1', '2020Q1', '2020Q1', '2020Q1', '2020Q1', '2020Q2', '2020Q2', '2020Q2', '2020Q2', '2020Q2', '2020Q2', '2020Q3', '2020Q3', '2020Q3', '2020Q3', '2020Q3', '2020Q4', '2020Q4', '2020Q4', '2020Q4', '2021Q1', '2021Q1', '2021Q1', '2021Q2', '2021Q2', '2021Q3'],
'inicio': ['2020Q1', '2020Q2', '2020Q3', '2020Q4', '2021Q1', '2021Q2', '2021Q3', '2020Q2', '2020Q3', '2020Q4', '2021Q1', '2021Q2', '2021Q3', '2020Q3', '2020Q4', '2021Q1', '2021Q2', '2021Q3', '2020Q4', '2021Q1', '2021Q2', '2021Q3', '2021Q1', '2021Q2', '2021Q3', '2021Q2', '2021Q3', '2021Q3'],
'values': [56.0, 45.0, 15.0, 7.0, 4.0, 4.0, 3.0, 418.0, 277.0, 86.0, 46.0, 33.0, 28.0, 619.0, 398.0, 167.0, 122.0, 93.0, 1163.0, 916.0, 521.0, 319.0, 976.0, 680.0, 363.0, 811.0, 559.0, 1879.0]}
df = pd.DataFrame(data)
# pivot the dataframe
pv = df.pivot(index='dates', columns='inicio', values='values')
# create figure and subplots
fig = plt.figure(figsize=(20, 10))
ax1 = plt.subplot2grid((20, 10), (0, 0), colspan=19, rowspan=17)
ax2 = plt.subplot2grid((20, 10), (19, 0), colspan=19, rowspan=1)
sns.set(font_scale=0.95)
# create heatmap with square=False instead of True
sns.heatmap(pv, ax=ax1, annot=True, fmt=".0f", robust=True, linewidth=0.1, square=False, cmap="Blues")
sns.heatmap(pv.sum().to_frame().T, ax=ax2, annot=True, fmt=".0f", robust=True, linewidth=0.1, square=False, cmap="Blues", xticklabels=False, yticklabels=False)
ax1.set_yticklabels(pv.columns, rotation=0) # rotate the yticklabels
ax1.set(xlabel='', ylabel='') # remove x & y labels
ax2.set(xlabel='', ylabel='') # remove x & y labels
plt.show()