无法在 Streamlit 中显示 SHAP 文本可视化

Cannot display SHAP text visualization in Streamlit

我正在尝试构建我的 NLP 项目的仪表板。因此,我使用 BERT 模型进行预测,使用 SHAP 包进行可视化,使用 Streamlit 来创建仪表板:

tokenizer = AutoTokenizer.from_pretrained(model_name_cla)
model = AutoModelForSequenceClassification.from_pretrained(model_name_cla)
labels = ['1- Tarife','2- Dateneingabe','3- Bestätigungsmail','4- Kundenbetreuung','5- Aufwand vom Vergleich bis Abschluss',
          '6- After-sales Wechselprozess','7 - Werbung/VX Kommunikation','8 - Sonstiges','9 - Nicht auswertbar']

def f(x):
    tv = torch.tensor([tokenizer.encode(v, padding='max_length', max_length=128, truncation=True) for v in x])
    attention_mask = (tv!=0).type(torch.int64)
    outputs = model(tv,attention_mask=attention_mask)[0].detach().cpu().numpy()
    scores = (np.exp(outputs).T / np.exp(outputs).sum(-1)).T
    val = sp.special.logit(scores)
    return val

text = ['This is just a test']

# build an explainer using a token masker
explainer = shap.Explainer(f, tokenizer, output_names=labels)

shap_values = explainer(text, fixed_context=1)

shap.plots.text(shap_values)

代码在我的 jupyter notebook 中运行良好,但是当我尝试将它作为 .py 文件执行时,没有任何反应。它既不显示任何内容也不抛出错误。我的控制台在执行时只是 returns 这个:

<IPython.core.display.HTML object> >

如何在 streamlit 中显示图表?

这可以用 Streamlit Components 和最新的 SHAP v0.36+(定义新的 getjs 方法)可视化,绘制 JS SHAP plots

(some plots like summary_plot are actually Matplotlib and can be plotted with st.pyplot)

import streamlit as st
import streamlit.components.v1 as components

def st_shap(plot, height=None):
    shap_html = f"<head>{shap.getjs()}</head><body>{plot.html()}</body>"
    components.html(shap_html, height=height)

st_shap(shap.plots.text(shap_values),400)

在此处 streamlit 中找到关于可视化 Shap 的更详细讨论 - Display SHAP diagrams with Streamlit

您可以捕获形状图的输出并使用 components.html

进行渲染
import streamlit.components.v1 as components
from IPython.core.interactiveshell import InteractiveShell
from IPython.utils import capture

def st_plot_text_shap(shap_val, height=None)
    InteractiveShell().instance()
    with capture.capture_output() as cap: 
        shap.plots.text(shap_val)
    components.html(cap.outputs[1].data['text/html'], height=height scrolling=True)