Seaborn 为多个配对图设置独特的分类颜色

Seaborn set color for unique categorical over several pair-plots

我正在使用 seabornt-SNE 来可视化 class separability/overlap 并且在我的数据集中包含五个 classes。因此,我的情节是一个 2x2 的子情节。我使用了生成下图的以下函数。

def pair_plot_tsne(df):

    tsne = TSNE(verbose=1, random_state=234) 

    df1 = df[(df['mode'] != 'car') & (df['mode'] != 'bus')]
    tsne1 = tsne.fit_transform(df1[cols].values) # cols - df's columns list
    df1['tsne_one'] = tsne1[:, 0]
    df1['tsne-two'] = tsne1[:, 1]

    df2 = df[(df['mode'] != 'foot') & (df['mode']!= 'bus')]
    tsne2 = tsne.fit_transform(df2[cols].values)
    df2['tsne_one'] = tsne2[:, 0]
    df2['tsne-two'] = tsne2[:, 1]

    df3 = df[df['mode'] != 'car']
    tsne3 = tsne.fit_transform(df3[cols].values)
    df3['tsne_one'] = tsne3[:, 0]
    df3['tsne-two'] = tsne3[:, 1]

    df4 = df[df['mode'] != 'foot']
    tsne4 = tsne.fit_transform(df4[cols].values)
    df4['tsne_one'] = tsne4[:, 0]
    df4['tsne-two'] = tsne4[:, 1]

    #create figure
    f = plt.figure(figsize=(16,4))

    ax1 = plt.subplot(2, 2, 1)
    sns.scatterplot( #df1 has 3 classes, so 3 colors
        x ='tsne_one', y='tsne-two', hue = 'mode', data = df1, palette = sns.color_palette('hls', 3), 
        legend='full', alpha = 0.7, ax = ax1 )

    ax2 = plt.subplot(2, 2, 2)
    sns.scatterplot( #df2 has 3 classes, so 3 colors
        x ='tsne_one', y='tsne-two', hue = 'mode', data = df2, palette = sns.color_palette('hls', 3), 
        legend='full', alpha = 0.7, ax = ax2 )

    ax3 = plt.subplot(2, 2, 3)
    sns.scatterplot( #df3 has 4 classes, so 4 colors
        x ='tsne_one', y='tsne-two', hue = 'mode', data = df3, palette = sns.color_palette('hls', 4), 
        legend='full', alpha = 0.7, ax = ax3 )

    ax4 = plt.subplot(2, 2, 4)
    sns.scatterplot( #df4 has 4 classes, so 4 colors
        x ='tsne_one', y='tsne-two', hue = 'mode', data = df4, palette = sns.color_palette('hls', 4),
        legend='full', alpha = 0.7, ax = ax4 )

    return f, ax1, ax2, ax3, ax4

因为我在每个子图中绘制数据集的一个子集,所以我希望每个 class 的颜色在出现的任何图中都保持一致。对于 class,car 模式的 blue 颜色出现在任何子图中,bus 模式出现的 bus 颜色出现,等等...

就像现在一样,footsubplot(2, 2, 1)中是红色的,carsubplot(2, 2, 2)中也被读取了,尽管其余的是一致的。

对于这个用例,seaborn 允许将字典作为调色板。字典将为每个色调值分配一种颜色。

以下是如何为您的数据创建此类字典的示例:

from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

df1 = pd.DataFrame({'tsne_one': np.random.randn(10),
                    'tsne-two': np.random.randn(10),
                    'mode': np.random.choice(['foot', 'metro', 'bike'], 10)})
df2 = pd.DataFrame({'tsne_one': np.random.randn(10),
                    'tsne-two': np.random.randn(10),
                    'mode': np.random.choice(['car', 'metro', 'bike'], 10)})
df3 = pd.DataFrame({'tsne_one': np.random.randn(10),
                    'tsne-two': np.random.randn(10),
                    'mode': np.random.choice(['foot', 'bus', 'metro', 'bike'], 10)})
df4 = pd.DataFrame({'tsne_one': np.random.randn(10),
                    'tsne-two': np.random.randn(10),
                    'mode': np.random.choice(['car', 'bus', 'metro', 'bike'], 10)})
modes = pd.concat([df['mode'] for df in (df1, df2, df3, df4)], ignore_index=True).unique()
colors = sns.color_palette('hls', len(modes))
palette = {mode: color for mode, color in zip(modes, colors)}

fig, axs = plt.subplots(2, 2, figsize=(12,6))
for df, ax in zip((df1, df2, df3, df4), axs.flatten()):
    sns.scatterplot(x='tsne_one', y='tsne-two', hue='mode', data=df, palette=palette, legend='full', alpha=0.7, ax=ax)

plt.tight_layout()
plt.show()