Plotly:如何将回归结果等数据嵌入到图例中?

Plotly: How to embed data like regression results into legend?

我正在使用 plotly 作为线性回归模型,我试图将 OLS 趋势线的结果作为图表图例嵌入。当我将鼠标悬停在线性回归线上时,我可以看到线性回归的详细信息,但我想将这些结果作为图表图例始终显示。

有没有办法做到这一点? 这是我的代码:

# Importing plotly dependency
import plotly.express as px

#Ploting the graph
fig = px.scatter(df_linear, x="Days_ct", y="Conf_ct", trendline="ols")
fig.update_traces(name = "OLS trendline")


fig.update_layout(template="ggplot2",title_text = '<b>Linear Regression Model</b>',
                  font=dict(family="Arial, Balto, Courier New, Droid Sans",color='black'), showlegend=True)
fig.update_layout(
    legend=dict(
        x=0.01,
        y=.98,
        traceorder="normal",
        font=dict(
            family="sans-serif",
            size=12,
            color="Black"
        ),
        bgcolor="LightSteelBlue",
        bordercolor="dimgray",
        borderwidth=2
    ))
fig.show()

根据您的设置和一些合成数据,您可以使用:

model = px.get_trendline_results(fig)
alpha = model.iloc[0]["px_fit_results"].params[0]
beta = model.iloc[0]["px_fit_results"].params[1]

然后将这些发现包含在您的图例中,并直接使用以下方法进行必要的布局调整:

fig.data[0].name = 'observations'
fig.data[0].showlegend = True
fig.data[1].name = fig.data[1].name  + ' y = ' + str(round(alpha, 2)) + ' + ' + str(round(beta, 2)) + 'x'
fig.data[1].showlegend = True

地块 1:

编辑:R 平方

根据您的评论,我将向您展示如何从回归分析中包含其他感兴趣的值。然而,在 图例 中继续包含估计值已经没有多大意义了。然而,这正是以下添加的作用:

rsq = model.iloc[0]["px_fit_results"].rsquared
fig.add_trace(go.Scatter(x=[100], y=[100],
                         name = "R-squared" + ' = ' + str(round(rsq, 2)),
                         showlegend=True,
                         mode='markers',
                         marker=dict(color='rgba(0,0,0,0)')
                         ))

图 2:图例中包含 R 平方

包含合成数据的完整代码:

import plotly.graph_objects as go
import plotly.express as px
import statsmodels.api as sm
import pandas as pd
import numpy as np
import datetime

# data
np.random.seed(123)
numdays=20

X = (np.random.randint(low=-20, high=20, size=numdays).cumsum()+100).tolist()
Y = (np.random.randint(low=-20, high=20, size=numdays).cumsum()+100).tolist()

df_linear = pd.DataFrame({'Days_ct': X, 'Conf_ct':Y})

#Ploting the graph
fig = px.scatter(df_linear, x="Days_ct", y="Conf_ct", trendline="ols")
fig.update_traces(name = "OLS trendline")



fig.update_layout(template="ggplot2",title_text = '<b>Linear Regression Model</b>',
                  font=dict(family="Arial, Balto, Courier New, Droid Sans",color='black'), showlegend=True)
fig.update_layout(
    legend=dict(
        x=0.01,
        y=.98,
        traceorder="normal",
        font=dict(
            family="sans-serif",
            size=12,
            color="Black"
        ),
        bgcolor="LightSteelBlue",
        bordercolor="dimgray",
        borderwidth=2
    ))

# retrieve model estimates
model = px.get_trendline_results(fig)
alpha = model.iloc[0]["px_fit_results"].params[0]
beta = model.iloc[0]["px_fit_results"].params[1]

# restyle figure
fig.data[0].name = 'observations'
fig.data[0].showlegend = True
fig.data[1].name = fig.data[1].name  + ' y = ' + str(round(alpha, 2)) + ' + ' + str(round(beta, 2)) + 'x'
fig.data[1].showlegend = True

# addition for r-squared
rsq = model.iloc[0]["px_fit_results"].rsquared
fig.add_trace(go.Scatter(x=[100], y=[100],
                         name = "R-squared" + ' = ' + str(round(rsq, 2)),
                         showlegend=True,
                         mode='markers',
                         marker=dict(color='rgba(0,0,0,0)')
                         ))

fig.show()