从 plotly imshow / heatmap 的最后一行和最后一列中删除色标

Remove colorscale from last row and column of a plotly imshow / heatmap

我有一个矩阵,我正在用它绘制下面的图。

正如人们所看到的那样,绘图的总行和列部分主导了色标,这反过来又削弱了其他值的视觉效果。我的目标是从最后一行和最后一列中删除色标,但我不确定该怎么做。

这是我现在的剧情代码

def annotate_matrix_values(df: pd.DataFrame, fig: go.Figure) -> go.Figure:
    df_matrix = df.to_records(index=False)
    for i, r in enumerate(df_matrix):
        for k, c in enumerate(r):
            fig.add_annotation(
                x=k,
                y=i,
                text=f"<b>{str(int(c))}</b>" if not math.isnan(c) else "",
                showarrow=False,
            )
    fig.update_xaxes(side="top")
    fig.update_coloraxes(showscale=False)
    return fig
  

df_final = pd.crosstab(pt[y_axis_name], pt[x_axis_name], margins=True)

fig = px.imshow(
            df_final,
            labels=dict(x=x_title, y=y_title, color="Value"),
            x=x_label,
            y=y_label,
            color_continuous_scale="Purples",
            aspect="auto",
        )
fig = annotate_matrix_values(df_final, fig)

当然我可以在这里设置 margin false 得到下面的图。但是我无法显示总数。

df_final = pd.crosstab(pt[y_axis_name], pt[x_axis_name], margins=False)

所以我的目标是显示总数但不显示任何色标。感谢帮助

我用这个例子解决了这个问题:

import pandas as pd
import plotly.graph_objects as go
import math

df = pd.DataFrame({'animal' : ['fish','fish','fish','monkey','tiger'], 
                   'zoo':['Paris','Grenoble','Paris','Toulose','Paris']})
df_cross = pd.crosstab(df.animal,df.zoo, margins=True)

df_matrix = df_cross.to_records(index=False)

fig = px.imshow(df_cross, color_continuous_scale="Purples")

for i, r in enumerate(df_matrix):
    for k, c in enumerate(r):
        fig.add_annotation(
                x=k,
                y=i,
                text=f"<b>{str(int(c))}</b>" if not math.isnan(c) else "",
                font_color = "white",
                showarrow=False,
            )

fig.show()

添加这些行来解决问题:

fig.layout['coloraxis']['cauto']=False 
fig.layout['coloraxis']['cmax'] = fig.data[0]['z'][:-1,:-1].max()
fig.show()

色阶根据行和列的最大值调整,不包括总行和列的值。总计行和列中大于最大值的值用与最大值相同的颜色着色。

理解问题的目的是从色标中排除总数,我尝试了各种方法,包括子图和覆盖空白图和热图,但没有得到好的结果。最简单的方法是在 EXPRESS 热图中添加一个四边形和注释。数据以 x,y 的形式创建,带有分类变量。热图是使用不包括总计的数据创建的,注释是使用总计作为字符串创建的。

import plotly.graph_objects as go
import plotly.express as px
import numpy as np
import pandas as pd

data = np.random.randint(0,10,60).reshape(12,5)

cols = ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday']
rows = ['Jan.','Feb.','Mar.','Apr.','May','June','July','Aug.','Sept.','Oct.','Nov.','Dec.']

df = pd.DataFrame(data, columns=cols, index=rows)
df['month_sum'] = df.sum(axis=1)
df.loc['week_sum'] = df.sum(axis=0)

fig = px.imshow(
    df.iloc[:-1,:-1].values,
    labels=dict(x="Day of Week", y="Year of Month", color="Value"),
    x=cols,
    y=rows,
    color_continuous_scale="Purples",
    aspect="auto",
    text_auto='auto'
        )

for i,r in enumerate(df['month_sum']):
    fig.add_shape(type='rect',
                  x0=4.5, y0=-0.5+i, x1=5.5, y1=0.5+i,
                  line=dict(
                      color='rgb(188,189,220)',
                      width=1,
                  ),
                  fillcolor='white',
                 )
    fig.add_annotation(
        x=5,
        y=i,
        text=str(r),
        showarrow=False
    )

for i,c in enumerate(df.loc['week_sum'].tolist()):
    if i == 5:
        break
    fig.add_shape(type='rect',
                  x0=-0.5+i, y0=12.5, x1=0.5+i, y1=11.5,
                  line=dict(
                      color='rgb(188,189,220)',
                      width=1,
                  ),
                  fillcolor='white',
                 )
    fig.add_annotation(
        x=i,
        y=12.0,
        text=str(c),
        showarrow=False
    )

fig.add_annotation(
    x=1.0,
    y=1.04,
    xref='paper',
    yref='paper',
    text='Month Total',
    showarrow=False,
)

fig.add_annotation(
    x=-0.16,
    y=0.02,
    xref='paper',
    yref='paper',
    text='Week Total',
    showarrow=False,
)

fig.update_xaxes(side="top")
fig.update_layout(
    autosize=False,
    width=600,
    height=600,
    coloraxis=dict(showscale=False)
                 )

fig.show()