如何为 seaborn histplot 绘制不均匀数量的子图
How to plot uneven number of subplots for seaborn histplot
我目前有一个 13 列的列表,我正在绘制分布图。我想创建一系列子图,以便这些图占用更少 space 但我很难在循环中这样做。
示例数据帧:
import pandas as pd
import numpy as np
data = {'identifier': ['A', 'B', 'C', 'D'],
'treatment': ['untreated', 'treated', 'untreated', 'treated'], 'treatment_timing': ['pre', 'pre', 'post', 'post'],
'subject_A': [1.3, 0.0, 0.5, 1.6], 'subject_B': [2.0, 1.4, 0.0, 0.0], 'subject_C': [nan, 3.0, 2.0, 0.5],
'subject_D': [np.nan, np.nan, 1.0, 1.6], 'subject_E': [0, 0, 0, 0], 'subject_F': [1.0, 1.0, 0.4, 0.5]}
df = pd.DataFrame(data)
identifier treatment treatment_timing subject_A subject_B subject_C subject_D subject_E subject_F
0 A untreated pre 1.3 2.0 NaN NaN 0 1.0
1 B treated pre 0.0 1.4 3.0 NaN 0 1.0
2 C untreated post 0.5 0.0 2.0 1.0 0 0.4
3 D treated post 1.6 0.0 0.5 1.6 0 0.5
- 它从 subject_A 到 subject_M(总共 13 个)。
- 我目前所做的是生成 13 行 1 列的 13 直方图布局。每个主题一个,分为 3 种颜色(预、post 和缺失)。
这是我目前拥有的:
fig, axes = plt.subplots(3,5, sharex=True, figsize=(12,6))
for index, col in enumerate(COL_LIST):
sns.histplot(
df ,x=col, hue="time", multiple="dodge", bins=10, ax=axes[index,index % 3]
).set_title(col.replace("_", " "))
plt.tight_layout()
这绝对不行。但我不确定是否有一种简单的方法来定义轴,而无需复制和粘贴此行 13 次并手动定义轴坐标。
使用 displot 有点麻烦,因为 col_wrap 会出错
ValueError: Number of rows must be a positive integer, not 0
(我相信这是由于np.nan的存在)
- 使用
seaborn.displot
, which is a FacetGrid 比 seaborn.histplot
更容易。
- 探索使用
row
、col
和 col_wrap
来获取所需的行数和列数。
- 必须堆叠
subject_
列,以将数据帧转换为整齐的格式,这可以通过 .stack
完成
import pandas as pd
import seaborn as sns
# convert the dataframe into a long form with stack
df_long = df.set_index(['identifier', 'treatment', 'treatment_timing']).stack().reset_index().rename(columns={'level_3': 'subject', 0: 'vals'})
# sort by subject
df_long = df_long.sort_values('subject').reset_index(drop=True)
# display(df_long.head())
identifier treatment treatment_timing subject vals
0 A untreated pre subject_A 1.3
1 D treated post subject_A 1.6
2 C untreated post subject_A 0.5
3 B treated pre subject_A 0.0
4 D treated post subject_B 0.0
# plot with displot
sns.displot(data=df_long, row='subject', col='treatment', x='vals', hue='treatment_timing', bins=10)
我目前有一个 13 列的列表,我正在绘制分布图。我想创建一系列子图,以便这些图占用更少 space 但我很难在循环中这样做。
示例数据帧:
import pandas as pd
import numpy as np
data = {'identifier': ['A', 'B', 'C', 'D'],
'treatment': ['untreated', 'treated', 'untreated', 'treated'], 'treatment_timing': ['pre', 'pre', 'post', 'post'],
'subject_A': [1.3, 0.0, 0.5, 1.6], 'subject_B': [2.0, 1.4, 0.0, 0.0], 'subject_C': [nan, 3.0, 2.0, 0.5],
'subject_D': [np.nan, np.nan, 1.0, 1.6], 'subject_E': [0, 0, 0, 0], 'subject_F': [1.0, 1.0, 0.4, 0.5]}
df = pd.DataFrame(data)
identifier treatment treatment_timing subject_A subject_B subject_C subject_D subject_E subject_F
0 A untreated pre 1.3 2.0 NaN NaN 0 1.0
1 B treated pre 0.0 1.4 3.0 NaN 0 1.0
2 C untreated post 0.5 0.0 2.0 1.0 0 0.4
3 D treated post 1.6 0.0 0.5 1.6 0 0.5
- 它从 subject_A 到 subject_M(总共 13 个)。
- 我目前所做的是生成 13 行 1 列的 13 直方图布局。每个主题一个,分为 3 种颜色(预、post 和缺失)。
这是我目前拥有的:
fig, axes = plt.subplots(3,5, sharex=True, figsize=(12,6))
for index, col in enumerate(COL_LIST):
sns.histplot(
df ,x=col, hue="time", multiple="dodge", bins=10, ax=axes[index,index % 3]
).set_title(col.replace("_", " "))
plt.tight_layout()
这绝对不行。但我不确定是否有一种简单的方法来定义轴,而无需复制和粘贴此行 13 次并手动定义轴坐标。
使用 displot 有点麻烦,因为 col_wrap 会出错
ValueError: Number of rows must be a positive integer, not 0
(我相信这是由于np.nan的存在)
- 使用
seaborn.displot
, which is a FacetGrid 比seaborn.histplot
更容易。- 探索使用
row
、col
和col_wrap
来获取所需的行数和列数。
- 探索使用
- 必须堆叠
subject_
列,以将数据帧转换为整齐的格式,这可以通过.stack
完成
import pandas as pd
import seaborn as sns
# convert the dataframe into a long form with stack
df_long = df.set_index(['identifier', 'treatment', 'treatment_timing']).stack().reset_index().rename(columns={'level_3': 'subject', 0: 'vals'})
# sort by subject
df_long = df_long.sort_values('subject').reset_index(drop=True)
# display(df_long.head())
identifier treatment treatment_timing subject vals
0 A untreated pre subject_A 1.3
1 D treated post subject_A 1.6
2 C untreated post subject_A 0.5
3 B treated pre subject_A 0.0
4 D treated post subject_B 0.0
# plot with displot
sns.displot(data=df_long, row='subject', col='treatment', x='vals', hue='treatment_timing', bins=10)