在 seaborn catplot 中为每个类别添加从最小点到最大点的水平线

Add horizontal lines from min point to max point per category in seaborn catplot

我正在尝试为选定的客户可视化每个季度的不同类型的“购买”。为了生成此视觉效果,我在 seaborn 中使用了 catplot 功能,但无法添加连接每个购买的水果的水平线。 每行应从每个水果的第一个点开始,到同一个水果的最后一个点结束。关于如何以编程方式执行此操作的任何想法?

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

dta = pd.DataFrame(columns=["Date", "Fruit", "type"], data=[['2017-01-01','Orange', 
'FP'], ['2017-04-01','Orange', 'CP'], ['2017-07-01','Orange', 'CP'], 
['2017-10-08','Orange', 'CP'],['2017-01-01','Apple', 'NP'], ['2017-04-01','Apple', 'CP'], 
['2017-07-01','Banana', 'NP'], ['2017-10-08','Orange', 'CP']
                                                        ])
dta['quarter'] = pd.PeriodIndex(dta.Date, freq='Q')

sns.catplot(x="quarter", y="Fruit", hue="type", kind="swarm", data=dta)

plt.show()

这是结果:

.

如何添加单独的水平线,每条水平线都连接购买橙子和苹果的点?

您只需按如下方式为图表启用水平网格:

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

dta = pd.DataFrame(
    columns=["Date", "Fruit", "type"],
    data=[
        ["2017-01-01", "Orange", "FP"],
        ["2017-04-01", "Orange", "CP"],
        ["2017-07-01", "Orange", "CP"],
        ["2017-10-08", "Orange", "CP"],
        ["2017-01-01", "Apple", "NP"],
        ["2017-04-01", "Apple", "CP"],
        ["2017-07-01", "Banana", "NP"],
        ["2017-10-08", "Orange", "CP"],
    ],
)


dta["quarter"] = pd.PeriodIndex(dta.Date, freq="Q")
sns.catplot(x="quarter", y="Fruit", hue="type", kind="swarm", data=dta)
plt.grid(axis='y')
plt.show()

预览

Each line should start at the first dot for each fruit and end at the last dot for the same fruit.

  1. 使用 groupby.ngroup 将四分之一映射到 xtick 位置
  2. 使用groupby.agg找到每个水果的最小和最大xtick端点
  3. 使用ax.hlines绘制从每个水果的最小值到最大值的水平线
df = pd.DataFrame([['2017-01-01', 'Orange', 'FP'], ['2017-04-01', 'Orange', 'CP'], ['2017-07-01', 'Orange', 'CP'], ['2017-10-08', 'Orange', 'CP'], ['2017-01-01', 'Apple', 'NP'], ['2017-04-01', 'Apple', 'CP'], ['2017-07-01', 'Banana', 'NP'], ['2017-10-08', 'Orange', 'CP']], columns=['Date', 'Fruit', 'type'])
df['quarter'] = pd.PeriodIndex(df['Date'], freq='Q')

df = df.sort_values('quarter')                            # sort dataframe by quarter
df['xticks'] = df.groupby('quarter').ngroup()             # map quarter to xtick position
ends = df.groupby('Fruit')['xticks'].agg(['min', 'max'])  # find min and max xtick per fruit

g = sns.catplot(x='quarter', y='Fruit', hue='type', kind='swarm', s=8, data=df)
g.axes[0, 0].hlines(ends.index, ends['min'], ends['max']) # plot horizontal lines from each fruit's min to max

详细分类:

  1. catplot 按照它们在数据框中出现的顺序绘制 xtick。示例数据帧已按 quarter 排序,但实际数据帧应明确排序:

    df = df.sort_values('quarter')
    
  2. 使用 groupby.ngroup:

    将四分之一区映射到它们的 xtick 位置
    df['xticks'] = df.groupby('quarter').ngroup()
    
    #          Date   Fruit  type  quarter  xticks
    # 0  2017-01-01  Orange    FP   2017Q1       0
    # 1  2017-04-01  Orange    CP   2017Q2       1
    # 2  2017-07-01  Orange    CP   2017Q3       2
    # 3  2017-10-08  Orange    CP   2017Q4       3
    # 4  2017-01-01   Apple    NP   2017Q1       0
    # 5  2017-04-01   Apple    CP   2017Q2       1
    # 6  2017-07-01  Banana    NP   2017Q3       2
    # 7  2017-10-08  Orange    CP   2017Q4       3
    
  3. 使用 groupby.agg:

    找到最小值和最大值 xticks 以获得每个 Fruit 的端点
    ends = df.groupby('Fruit')['xticks'].agg(['min', 'max'])
    
    #         min  max
    # Fruit           
    # Apple     0    1
    # Banana    2    2
    # Orange    0    3
    
  4. 使用ax.hlines绘制一条从min-endpoint到max-endpoint的每个Fruit的水平线:

    g = sns.catplot(x='quarter', y='Fruit', hue='type', kind='swarm', s=8, data=df)
    ax = g.axes[0, 0]
    ax.hlines(ends.index, ends['min'], ends['max'])