无法使用 ktrain 模型在 Front end streamlit 上进行预测,请提供有关如何为预测功能提供输入的建议

Can not predict on Front end streamlit with ktrain model, kindly provide suggestions about how to provide input for predict function

无法使用ktrain模型在Front end streamlit上进行预测,请提供有关如何为预测功能提供输入的建议。

基本上我想了解如何为我保存的 ktrain 回归模型提供输入,以便我可以将它合并到 streamlit 网络应用程序按钮中。

我已经尝试将数组、列表和数据框作为参数放入 .predict 函数中,但似乎仍然遗漏了一些东西。在点击预测按钮时出现值错误。

import streamlit as st
from PIL import Image 
import pandas as pd
from tensorflow import keras
model = keras.models.load_model("predictor.h5")


st.write("This is an application to calculate Employee Mental Fatigue Score")
image = Image.open("IMG_2605.jpeg")
st.image(image, use_column_width=True)

WFH_Setup_Available = st.text_input("is work from home enabled for you?")
Designation =st.text_input("what is your designation?")
Average_hours_worked_per_day = st.text_input("how many hours you work on an average per day?")
Employee_satisfaction_score = st.text_input("Please enter your satisfaction score on scale of 10")
data = ['WFH_Setup_Available', 'Designation', 'Average_hours_worked_per_day' , 'Employee_satisfaction_score']


def mental_fatigue_score(WFH_Setup_Available, Designation, Average_hours_worked_per_day, Employee_satisfaction_score):
  prediction = model.predict([[WFH_Setup_Available, Designation, Average_hours_worked_per_day, Employee_satisfaction_score]])
  print(prediction)
  return prediction


if st.button("Predict"):
  result= mental_fatigue_score(WFH_Setup_Available, Designation, Average_hours_worked_per_day, Employee_satisfaction_score)
  st.success('The output is {}'.format(result))

请建议如何为 streamlit 网络应用程序的 .predict 函数提供输入。 我已经使用 ktrain 回归器训练了预测器。

通过将 ktrain 模型保存为

自己解决了
predictor.save('predictor')
predictor = ktrain.load_predictor('predictor')

当我保存为预测器时,它会创建一个文件夹,其中我有一个 tf_mode.h5 & tf_model.preproc.

这比我预期的要容易。

进一步的火车输入应该是像下面这样的数据框-

data = {'WFH_Setup_Available':WFH_Setup_Available,'Designation':Designation, 'Company_Type':Company_Type, 
        'Average_hours_worked_per_day': Average_hours_worked_per_day, 'Employee_satisfaction_score': Employee_satisfaction_score}

data = pd.DataFrame([data])