How do I get a legend alongside my categorically colored scatter plot?

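I created a scatter plot with matplotlib and colored it by one of the columns in my data frame. Now I would like to add a legend that shows which color stands for which data. However, simply calling plt.legend() without labels does not work, and neither does adding a label to my plt.scatter call.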

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

%matplotlib notebook

color = {
    "Africa" : "red",
    "Americas" : "green",
    "Eastern Mediterranean": "blue",
    "Europe" : "yellow",
    "South-East Asia": "black",
    "Western Pacific" : "orange"
}

data.columns = ['Country', 'GDP', 'Region', 'Air pollution (ug/m3)']
data['Color'] = data['Region'].map(color)

plt.scatter(data['GDP'], data['Air pollution (ug/m3)'], picker= 0, c = data['Color'], label = data['Region'])
plt.legend()
    
def onpick(event):
    origin = data.iloc[event.ind[0]]['Country']
    plt.gca().set_title('Selected item came from {}'.format(origin))

plt.gcf().canvas.mpl_connect('pick_event', onpick)

This is what it currently looks like:

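However, I would like the legend to mirror the color dictionary: a dot in each color, followed by the Region name. What is the best way to do that?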

If you are not attached to these specific colors, you can simply use sns.scatterplot as in the code below, without mapping each color yourself:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from random import sample
import seaborn as sns

N = 100
data = pd.DataFrame({'GDP': np.random.random(N),
                     'Air pollution (ug/m3)': np.random.random(N),
                     'Region': sample(['Africa',
                                       'Americas',
                                       'Eastern Mediterranean',
                                       'Europe',
                                       'South-East Asia',
                                       'Western Pacific']*N, N)})

sns.scatterplot(data = data,
                x = 'GDP',
                y = 'Air pollution (ug/m3)',
                hue = 'Region')
plt.legend(bbox_to_anchor = (1.05, 0.98), loc = 'upper left')

plt.show()

Otherwise, if you want to keep your colors, you can redefine the color cycler:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from random import sample
import seaborn as sns
from cycler import cycler

N = 100
data = pd.DataFrame({'GDP': np.random.random(N),
                     'Air pollution (ug/m3)': np.random.random(N),
                     'Region': sample(['Africa',
                                       'Americas',
                                       'Eastern Mediterranean',
                                       'Europe',
                                       'South-East Asia',
                                       'Western Pacific']*N, N)})

default_cycler = cycler(color=['red', 'green', 'blue', 'yellow', 'black', 'orange'])
plt.rc('axes', prop_cycle=default_cycler)

sns.scatterplot(data = data,
                x = 'GDP',
                y = 'Air pollution (ug/m3)',
                hue = 'Region')
plt.legend(bbox_to_anchor = (1.05, 0.98), loc = 'upper left')

plt.show()
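
One caveat with the cycler: as far as I can tell, seaborn hands out the cycle colors in the order in which the regions are first encountered in the data, so "Africa" is not guaranteed to end up red. If the exact mapping from the question matters, a sketch like the following passes the color dictionary directly through the palette argument of sns.scatterplot. The random data is only a stand-in for the real data frame, and hue_order is an optional extra that fixes the order of the legend entries.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Same region -> color mapping as in the question.
color = {
    "Africa": "red",
    "Americas": "green",
    "Eastern Mediterranean": "blue",
    "Europe": "yellow",
    "South-East Asia": "black",
    "Western Pacific": "orange",
}

# Synthetic stand-in for the real data (an assumption for illustration).
rng = np.random.default_rng(0)
N = 100
data = pd.DataFrame({
    "GDP": rng.random(N),
    "Air pollution (ug/m3)": rng.random(N),
    "Region": rng.choice(list(color), size=N),
})

# A dict palette maps each hue level to a fixed color;
# hue_order makes the legend follow the order of the dictionary.
ax = sns.scatterplot(data=data,
                     x="GDP",
                     y="Air pollution (ug/m3)",
                     hue="Region",
                     hue_order=list(color),
                     palette=color)
ax.legend(bbox_to_anchor=(1.05, 0.98), loc="upper left")
```

When running this as a script rather than in a notebook, finish with plt.show().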


Regarding interactivity, as explained elsewhere:

Just as in any other case, you define the picker argument and connect the callback function

In your case:

sns.scatterplot(data = data,
                x = 'GDP',
                y = 'Air pollution (ug/m3)',
                hue = 'Region',
                picker = 4)
plt.legend(bbox_to_anchor = (1.05, 0.98), loc = 'upper left')

def onpick(event):
    origin = data.iloc[event.ind[0]]['Country']
    plt.gca().set_title('Selected item came from {}'.format(origin))

plt.gcf().canvas.mpl_connect('pick_event', onpick)
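
If you would rather stay with plain matplotlib, one possible approach is to keep the single plt.scatter call from the question and build the legend from proxy artists, one per entry of the color dictionary. Keeping a single scatter call means event.ind still lines up with the row positions of the data frame, which the onpick callback relies on. The sketch below uses made-up data, including a placeholder 'Country' column, purely for illustration.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# Same region -> color mapping as in the question.
color = {
    "Africa": "red",
    "Americas": "green",
    "Eastern Mediterranean": "blue",
    "Europe": "yellow",
    "South-East Asia": "black",
    "Western Pacific": "orange",
}

# Hypothetical stand-in for the real data frame.
rng = np.random.default_rng(0)
N = 60
data = pd.DataFrame({
    "Country": [f"Country {i}" for i in range(N)],
    "GDP": rng.random(N),
    "Air pollution (ug/m3)": rng.random(N),
    "Region": rng.choice(list(color), size=N),
})
data["Color"] = data["Region"].map(color)

fig, ax = plt.subplots()

# One scatter call for all points, so event.ind indexes rows of `data`.
ax.scatter(data["GDP"], data["Air pollution (ug/m3)"],
           c=data["Color"], picker=True)

# Proxy artists: one marker-only Line2D per region, used only by the legend.
handles = [Line2D([], [], marker="o", linestyle="None", color=c, label=region)
           for region, c in color.items()]
ax.legend(handles=handles, title="Region",
          bbox_to_anchor=(1.05, 1), loc="upper left")

def onpick(event):
    # Show the country of the first point under the cursor.
    origin = data.iloc[event.ind[0]]["Country"]
    ax.set_title(f"Selected item came from {origin}")
    fig.canvas.draw_idle()

fig.canvas.mpl_connect("pick_event", onpick)
```

The legend then lists every region in the dictionary, whether or not it occurs in the data. When running this as a script, finish with plt.show().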