How to show only color coding in the legend of my plotly scatterplot in python

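I'm plotting a PCA with the plotly.express scatter function, coding the samples by region (color) and breed (symbol). When I plot it, the legend lists all 67 breeds as combinations of symbol and color. Is there a way to show only the color categories?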

My data looks like this:

         PC1    PC2    PC3    Breed   Region
Sample1  value  value  value  breed1  Region1
Sample2  value  value  value  breed2  Region1
Sample3  value  value  value  breed3  Region2
Sample4  value  value  value  breed1  Region1

Right now my code is just the basic command:

import plotly.express as px

fig = px.scatter(
    pca,
    x="PC2",
    y="PC1",
    color="Region",
    symbol="Breed",
    labels={
        "PC2": "PC2-{}%".format(eigen[1]),
        "PC1": "PC1-{}%".format(eigen[0]),
    },
)

fig.update_layout(showlegend=True, height=800, width=800)
fig.show()

Any ideas?

You can add these lines:

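region_lst = []

for trace in fig["data"]:
    # px names each trace "<color value>, <symbol value>"; keep only the color part
    trace["name"] = trace["name"].split(",")[0]

    # show one legend entry per region, taken from a trace drawn with circles
    # (a region that has no circle trace gets no legend entry)
    if trace["name"] not in region_lst and trace["marker"]["symbol"] == "circle":
        trace["showlegend"] = True
        region_lst.append(trace["name"])
    else:
        trace["showlegend"] = False

fig.update_layout(legend_title="region")
fig.show()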

Before adding these lines: [figure omitted]

After adding these lines: [figure omitted]

I used this dataframe:

import plotly.express as px

df = px.data.medals_long()
fig = px.scatter(df, y="nation", x="count", color="medal", symbol="count")
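
If you would rather not depend on every region having a circle trace, here is a minimal sketch of the same legend clean-up as a helper. The function name `one_legend_entry_per_region` is my own, and the demo uses plain dicts as stand-ins for traces; it only relies on item access, which plotly traces also support (the loop above uses it). Setting `legendgroup` is an addition that is not in the snippet above: it should make a click on a legend entry toggle every trace of that region, not just the one that owns the entry.

```python
def one_legend_entry_per_region(traces, sep=","):
    """Rename traces to their region and keep one legend entry per region.

    Assumes each trace name starts with the region, followed by `sep`
    (the "<Region>, <Breed>" naming used when both color and symbol are set).
    """
    seen = set()
    for trace in traces:
        region = trace["name"].split(sep)[0].strip()
        trace["name"] = region
        # traces sharing a legendgroup are toggled together from the legend
        trace["legendgroup"] = region
        # only the first trace of each region gets a legend entry
        trace["showlegend"] = region not in seen
        seen.add(region)
    return traces


# Stand-in traces (plain dicts) with hypothetical names
demo_traces = [
    {"name": "Region1, breed1"},
    {"name": "Region1, breed2"},
    {"name": "Region2, breed3"},
]
one_legend_entry_per_region(demo_traces)

for t in demo_traces:
    print(t["name"], t["showlegend"])
```

With a real figure you would call `one_legend_entry_per_region(fig.data)` before `fig.show()`. One trade-off: the legend marker for each region is then whatever symbol the first trace of that region uses, not necessarily a circle.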
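  • Specifying both color and symbol makes the legend show every combination of the values in the two columns
  • To have the legend show the values of only one column, use color alone
  • To still represent the second column with symbols, update each trace to use a list of symbols
  • The data below is synthesized from your description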
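
To see why the legend grows so quickly, it may help to count the entries by hand. This is only a sketch on made-up rows (the `samples` list is hypothetical, mirroring the table in the question); it assumes, as described in the first point above, that the legend gets one entry per distinct combination of the color and symbol values present in the data.

```python
# Hypothetical sample rows: (Region, Breed) for each sample
samples = [
    ("Region1", "breed1"),
    ("Region1", "breed2"),
    ("Region2", "breed3"),
    ("Region1", "breed1"),
]

# color + symbol: one legend entry per distinct (Region, Breed) pair
entries_color_and_symbol = sorted(set(samples))

# color only: one legend entry per distinct Region
entries_color_only = sorted({region for region, _ in samples})

print(len(entries_color_and_symbol))  # 3
print(len(entries_color_only))        # 2
```

With 67 breeds spread over several regions the first count gets large, which is why the code below colors by Region only and assigns the symbols afterwards.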

Full code

import pandas as pd
import numpy as np
import plotly.express as px
from plotly.validators.scatter.marker import SymbolValidator

eigen = [0.5, 0.7]

# simulate data
n = 1000
pca = pd.DataFrame(
    {
        **{f"PC{c}": np.random.uniform(1, 5, n) for c in range(1, 4, 1)},
        **{
            "Breed": np.random.choice([f"breed{x}" for x in range(67)], n),
            "Region": np.random.choice([f"Region{x}" for x in range(10)], n),
        },
    }
)

# just color by Region
fig = px.scatter(
    pca,
    x="PC2",
    y="PC1",
    color="Region",
    labels={"PC2": "PC2-{}%".format(eigen[1]), "PC1": "PC1-{}%".format(eigen[0])},
)

# build a dict that maps each Breed to a marker symbol name
symbol_map = {
    t: s for t, s in zip(np.sort(pca["Breed"].unique()), SymbolValidator().values[2::3])
}

# collect the breeds of each region, in row order
breeds_by_region = pca.groupby("Region")["Breed"].agg(list)

# for each trace, set one marker symbol per point; look the region up by trace
# name, because px orders traces by first appearance while groupby sorts its keys
for t in fig.data:
    t.update(marker_symbol=[symbol_map[b] for b in breeds_by_region[t.name]])

fig.show()