在 matplotlib、scatter 中使用 for 循环时如何自定义颜色图例

how to customize color legend when using for loop in matplotlib, scatter

我想绘制一个 3D 散点图,其中数据按组着色。这是数据示例:

aa=pd.DataFrame({'a':[1,2,3,4,5],
                 'b':[2,3,4,5,6],
                 'c':[1,3,4,6,9],
                 'd':[0,0,1,2,3],
                 'e':['abc','sdf','ert','hgf','nhkm']})

这里,a、b、c分别是x轴、y轴、z轴。 e 是散点图中显示的文本。我需要对数据进行分组并显示不同的颜色。

这是我的代码:

fig = plt.figure()
ax = fig.gca(projection='3d')
zdirs = aa.loc[:,'e'].__array__()
xs = aa.loc[:,'a'].__array__()
ys = aa.loc[:,'b'].__array__()
zs = aa.loc[:,'c'].__array__()
colors = aa.loc[:,'d'].__array__()
colors1=np.where(colors==0,'grey',
                 np.where(colors==1,'yellow',
                          np.where(colors==2,'green',
                                   np.where(colors==3,'pink','red'))))
for i in range(len(zdirs)): #plot each point + it's index as text above
    ax.scatter(xs[i],ys[i],zs[i],color=colors1[i])
    ax.text(xs[i],ys[i],zs[i],  '%s' % (str(zdirs[i])), size=10, zorder=1, color='k')
ax.set_xlabel('a')
ax.set_ylabel('b')
ax.set_zlabel('c')
plt.show()

但我不知道如何在情节上加上传说。我希望我的传说是这样的:

颜色和数字应匹配并订购。

谁能帮我定制颜色条?

首先,我冒昧地减少了您的代码:

  • 我建议创建一个 ListedColormap 来映射整数->颜色,它允许您通过 c=aa['d'] 传递颜色列(注意它是 c=,而不是 color=! )
  • 这里不需要使用__array__(),在下面的代码中可以直接使用aa['a']
  • 最后,您可以为 ListedColormap 中的每种颜色添加一个空散点图,然后可以通过 ax.legend()
  • 正确渲染
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

from matplotlib.colors import ListedColormap
import matplotlib.patches as mpatches

aa=pd.DataFrame({'a':[1,2,3,4,5],
                 'b':[2,3,4,5,6],
                 'c':[1,3,4,6,9],
                 'd':[0,0,1,2,3],
                 'e':['abc','sdf','ert','hgf','nhkm']})

fig = plt.figure()
ax = fig.gca(projection='3d')

cmap = ListedColormap(['grey', 'yellow', 'green', 'pink','red'])
ax.scatter(aa['a'],aa['b'],aa['c'],c=aa['d'],cmap=cmap)

for x,y,z,label in zip(aa['a'],aa['b'],aa['c'],aa['e']):
    ax.text(x,y,z,label,size=10,zorder=1)

# Create a legend through an *empty* scatter plot
[ax.scatter([], [], c=cmap(i), label=str(i)) for i in range(len(aa))]
ax.legend()

ax.set_xlabel('a')
ax.set_ylabel('b')
ax.set_zlabel('c')

plt.show()