matplotlib:图例为 class 字符串的散点图

matplotlib: scatter plot with legend as string of class

绘制二进制分类的散点分布['cat', 'dog']

X, y = make_classification(n_samples=1000, n_features=2, n_redundant=0,
    n_clusters_per_class=1, weights=[0.9], flip_y=0, random_state=1,)

counter = Counter(y)

for label, _ in counter.items():
    row_ix = np.where(y == label)[0]
    plt.scatter(X[row_ix, 0], X[row_ix, 1], label=label)
plt.legend()
plt.show()

输出:

我想用 catdog 替换 01 图例。我来了:

for label, _ in counter.items():
    row_ix = np.where(y == label)[0]
    plt.scatter(X[row_ix, 0], X[row_ix, 1], label=['cat', 'dog'])
plt.legend()
plt.show()

输出:

您需要在 for 循环的每次迭代中更改图例标签的值,一种可能性是使用 zip

for item, animal in zip(counter.items(), ['cat', 'dog']):
    row_ix = np.where(y == item[0])[0]
    plt.scatter(X[row_ix, 0], X[row_ix, 1], label=animal)
plt.legend()
plt.show()