Adding a single label to the legend for a series of different data points plotted inside a designated bin in Python using matplotlib.pyplot.plot()

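I have a script that plots astronomical data for redmapper clusters from a csv file. I can read the data points from it and want to plot them in different colors according to their redshift values: I have split the data set into 3 bins by redshift (0.1-0.2, 0.2-0.25, 0.25-0.31).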

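The problem appears in my code after I work out which bin each data point belongs to: I want 3 labels in the legend, corresponding to the red, green and blue data points, but this does not happen and I don't know why. I am using plot() instead of scatter() because I also have to draw a best fit from the data in the same plot, so everything needs to be in one figure.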

import numpy as np
import matplotlib.pyplot as py
import csv

z = open("Sheet4CSV.csv","rU")
data = csv.reader(z)
x = []
y = []
ylow = []
yupp = []
xlow = []
xupp = []
redshift = []

for r in data:
    x.append(float(r[2]))
    y.append(float(r[5]))
    xlow.append(float(r[3]))
    xupp.append(float(r[4]))
    ylow.append(float(r[6]))
    yupp.append(float(r[7]))
    redshift.append(float(r[1]))

from operator import sub
xerr_l = map(sub,x,xlow)
xerr_u = map(sub,xupp,x)
yerr_l = map(sub,y,ylow)
yerr_u = map(sub,yupp,y)

py.xlabel("$Original\ Tx\ XCS\ pipeline\ Tx\ keV$")
py.ylabel("$Iterative\ Tx\ pipeline\ keV$")
py.xlim(0,12)
py.ylim(0,12)
py.title("Redmapper Clusters comparison of Tx pipelines")

ax1 = py.subplot(111)

##Problem starts here after the previous line##

for p in redshift:
    for i in xrange(84):
        p=redshift[i]

        if 0.1<=p<0.2:

            ax1.plot(x[i],y[i],color="b", marker='.', linestyle = " ")#, label = "$z < 0.2$")
            exit


        if 0.2<=p<0.25:
            ax1.plot(x[i],y[i],color="g", marker='.', linestyle = " ")#, label="$0.2 \leq z < 0.25$")
            exit

        if 0.25<=p<=0.3:
            ax1.plot(x[i],y[i],color="r", marker='.', linestyle = " ")#, label="$z \geq 0.25$")
            exit

##There seems nothing wrong after this point##

py.errorbar(x,y,yerr=[yerr_l,yerr_u],xerr=[xerr_l,xerr_u], fmt= " ",ecolor='magenta', label="Error bars")

cof = np.polyfit(x,y,1)
p = np.poly1d(cof)
l = np.linspace(0,12,100)
py.plot(l,p(l),"black",label="Best fit")
py.plot([0,15],[0,15],"black", linestyle="dotted", linewidth=2.0, label="line $y=x$")
py.grid()

box = ax1.get_position()
ax1.set_position([box.x1,box.y1,box.width, box.height])
py.legend(loc='center left',bbox_to_anchor=(1,0.5))
py.show()

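In the first 'for' loop, I index every value 'p' in the list 'redshift' so that the bins can be created with the 'if' statements. However, if I add the labels that are commented out after the hash on each py.plot() call inside the 'if' statements, every data point 'i' plotted as (x[i], y[i]) gets its own label, and my legend ends up with 87 labels in total (including the 3 set elsewhere in the code)!

I basically need 1 label per bin...

Please tell me what I need to do after creating the bins and calling py.plot()... Thanks in advance :-) Sorry, I can't post my image because my reputation is too low!

The data appended to the x, y and redshift lists from the csv file are as follows: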

x=[5.031,10.599,10.589,8.548,9.089,8.675,3.588,1.244,3.023,8.632,8.953,7.603,7.513,2.917,7.344,7.106,3.889,7.287,3.367,6.839,2.801,2.316,1.328,6.31,6.19,6.329,6.025,5.629,6.123,5.892,5.438,4.398,4.542,4.624,4.501,4.504,5.033,5.068,4.197,2.854,4.784,2.158,4.054,3.124,3.961,4.42,3.853,3.658,1.858,4.537,2.072,3.573,3.041,5.837,3.652,3.209,2.742,2.732,1.312,3.635,2.69,3.32,2.488,2.996,2.269,1.701,3.935,2.015,0.798,2.212,1.672,1.925,3.21,1.979,1.794,2.624,2.027,3.66,1.073,1.007,1.57,0.854,0.619,0.547]
y=[5.255,10.897,11.045,9.125,9.387,17.719,4.025,1.389,4.152,8.703,9.051,8.02,7.774,3.139,7.543,7.224,4.155,7.416,3.905,6.868,2.909,2.658,1.651,6.454,6.252,6.541,6.152,5.647,6.285,6.079,5.489,4.541,4.634,8.851,4.554,4.555,5.559,5.144,5.311,5.839,5.364,3.18,4.352,3.379,4.059,4.575,3.914,5.736,2.304,4.68,3.187,3.756,3.419,9.118,4.595,3.346,3.603,6.313,1.816,4.34,2.732,4.978,2.719,3.761,2.623,2.1,4.956,2.316,4.231,2.831,1.954,2.248,6.573,2.276,2.627,3.85,3.545,25.405,3.996,1.347,1.679,1.435,0.759,0.677]
redshift = [0.12,0.25,0.23,0.23,0.27,0.26,0.12,0.27,0.17,0.18,0.17,0.3,0.23,0.1,0.23,0.29,0.29,0.12,0.13,0.26,0.11,0.24,0.13,0.21,0.17,0.2,0.3,0.29,0.23,0.27,0.25,0.21,0.11,0.15,0.1,0.26,0.23,0.12,0.23,0.26,0.2,0.17,0.22,0.26,0.25,0.12,0.19,0.24,0.18,0.15,0.27,0.14,0.14,0.29,0.29,0.26,0.15,0.29,0.24,0.24,0.23,0.26,0.29,0.22,0.13,0.18,0.24,0.14,0.24,0.24,0.17,0.26,0.29,0.11,0.14,0.26,0.28,0.26,0.28,0.27,0.23,0.26,0.23,0.19]

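When handling numerical data like this, you should really consider using a numerical library such as numpy.

The problem in your code comes from handling one record at a time (a coordinate (x, y) and its redshift value). You call plot for every point, which creates a legend entry for each of the 84 data points. Instead, treat each bin as a group of points belonging to the same data set and plot it in one call. You can use "logical masks" to pick out your bins, as shown below.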
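To make the idea of a logical mask concrete, here is a minimal sketch on a small made-up sample (the five values below are illustrative only, not the real data set). Comparing an array with a number gives one boolean per element, and indexing with that boolean array keeps only the matching elements.

```python
import numpy as np

# Small made-up sample, only to illustrate masking (not the real data set).
redshift = np.array([0.12, 0.25, 0.23, 0.30, 0.17])
x = np.array([5.0, 10.6, 8.5, 7.6, 3.0])

# Comparing an array with a scalar yields one boolean per element.
low = (0.1 <= redshift) & (redshift < 0.2)
mid = (0.2 <= redshift) & (redshift < 0.25)
high = redshift >= 0.25

# Indexing with a mask keeps only the elements where the mask is True.
x_low = x[low]    # points with 0.1 <= z < 0.2
x_mid = x[mid]    # points with 0.2 <= z < 0.25
x_high = x[high]  # points with z >= 0.25
```

Because each mask selects a whole group at once, each group can be passed to a single plot call and therefore gets a single label.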

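It is also unclear why you write exit after each plot call. A bare exit without parentheses is just an expression that is evaluated and discarded, so it has no effect there.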

import numpy as np
import matplotlib.pyplot as plt

x = np.array([5.031,10.599,10.589,8.548,9.089,8.675,3.588,1.244,3.023,8.632,8.953,7.603,7.513,2.917,7.344,7.106,3.889,7.287,3.367,6.839,2.801,2.316,1.328,6.31,6.19,6.329,6.025,5.629,6.123,5.892,5.438,4.398,4.542,4.624,4.501,4.504,5.033,5.068,4.197,2.854,4.784,2.158,4.054,3.124,3.961,4.42,3.853,3.658,1.858,4.537,2.072,3.573,3.041,5.837,3.652,3.209,2.742,2.732,1.312,3.635,2.69,3.32,2.488,2.996,2.269,1.701,3.935,2.015,0.798,2.212,1.672,1.925,3.21,1.979,1.794,2.624,2.027,3.66,1.073,1.007,1.57,0.854,0.619,0.547])
y = np.array([5.255,10.897,11.045,9.125,9.387,17.719,4.025,1.389,4.152,8.703,9.051,8.02,7.774,3.139,7.543,7.224,4.155,7.416,3.905,6.868,2.909,2.658,1.651,6.454,6.252,6.541,6.152,5.647,6.285,6.079,5.489,4.541,4.634,8.851,4.554,4.555,5.559,5.144,5.311,5.839,5.364,3.18,4.352,3.379,4.059,4.575,3.914,5.736,2.304,4.68,3.187,3.756,3.419,9.118,4.595,3.346,3.603,6.313,1.816,4.34,2.732,4.978,2.719,3.761,2.623,2.1,4.956,2.316,4.231,2.831,1.954,2.248,6.573,2.276,2.627,3.85,3.545,25.405,3.996,1.347,1.679,1.435,0.759,0.677])
redshift = np.array([0.12,0.25,0.23,0.23,0.27,0.26,0.12,0.27,0.17,0.18,0.17,0.3,0.23,0.1,0.23,0.29,0.29,0.12,0.13,0.26,0.11,0.24,0.13,0.21,0.17,0.2,0.3,0.29,0.23,0.27,0.25,0.21,0.11,0.15,0.1,0.26,0.23,0.12,0.23,0.26,0.2,0.17,0.22,0.26,0.25,0.12,0.19,0.24,0.18,0.15,0.27,0.14,0.14,0.29,0.29,0.26,0.15,0.29,0.24,0.24,0.23,0.26,0.29,0.22,0.13,0.18,0.24,0.14,0.24,0.24,0.17,0.26,0.29,0.11,0.14,0.26,0.28,0.26,0.28,0.27,0.23,0.26,0.23,0.19])
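# Boolean masks: one True/False per point, selecting the members of each bin
bin1 = np.logical_and(0.1 <= redshift, redshift < 0.2)
bin2 = np.logical_and(0.2 <= redshift, redshift < 0.25)
bin3 = 0.25 <= redshift

# Raw strings keep the LaTeX backslashes intact
labels = (r"$z < 0.2$", r"$0.2 \leq z < 0.25$", r"$z \geq 0.25$")
# Blue, green, red from lowest to highest redshift, as in the question
colors = ('b', 'g', 'r')
# One plot call per bin, so each bin contributes exactly one legend entry
for mask, label, co in zip((bin1, bin2, bin3), labels, colors):
    plt.plot(x[mask], y[mask], color=co, ls='none', marker='o', label=label)
plt.legend()
plt.show()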
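Since the question also needs the best fit and the y=x line in the same figure, here is a rough sketch of how the binned points could be combined with them on one Axes. It assumes a small made-up sample in place of the CSV columns, and it selects the non-interactive "Agg" backend only so that it runs without a display; neither is part of the original script.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; drop this line for an interactive window
import matplotlib.pyplot as plt
import numpy as np

# Made-up sample values standing in for the CSV columns.
x = np.array([5.0, 10.6, 8.5, 7.6, 3.0, 1.2, 6.3, 4.4])
y = np.array([5.3, 10.9, 9.1, 8.0, 4.2, 1.4, 6.5, 4.5])
redshift = np.array([0.12, 0.25, 0.23, 0.30, 0.17, 0.27, 0.21, 0.11])

# (mask, color, label) for each redshift bin.
bins = [
    ((0.1 <= redshift) & (redshift < 0.2), "b", r"$z < 0.2$"),
    ((0.2 <= redshift) & (redshift < 0.25), "g", r"$0.2 \leq z < 0.25$"),
    (redshift >= 0.25, "r", r"$z \geq 0.25$"),
]

fig, ax = plt.subplots()

# One plot call per bin, so one legend entry per bin.
for mask, color, label in bins:
    ax.plot(x[mask], y[mask], color=color, marker=".", linestyle="none", label=label)

# Straight-line fit over all points, drawn on the same Axes.
slope, intercept = np.polyfit(x, y, 1)
grid = np.linspace(0, 12, 100)
ax.plot(grid, slope * grid + intercept, color="black", label="Best fit")
ax.plot([0, 15], [0, 15], color="black", linestyle="dotted", linewidth=2.0, label="line $y=x$")

ax.set_xlim(0, 12)
ax.set_ylim(0, 12)
ax.grid()
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))

# Collect what the legend will show: three bins plus the two lines.
handles, legend_labels = ax.get_legend_handles_labels()
```

With the real arrays substituted for the sample values, the legend should contain five entries instead of one per data point. Call plt.show() or fig.savefig(...) afterwards to view the result.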