初始化 NLClassifier 时出错:输入张量的类型不匹配 serving_default_input_type_ids:0。请求了 STRING,得到了 INT32

Error occurred when initializing NLClassifier: Type mismatch for input tensor serving_default_input_type_ids:0. Requested STRING, got INT32

我正在尝试学习如何将一些机器学习的东西用于 Android。我让 Text Classification demo 工作并且似乎工作正常。于是我尝试创建自己的模型。

我用来创建自己模型的代码是这样的:

import numpy as np
import os

from tflite_model_maker import model_spec
from tflite_model_maker import text_classifier
from tflite_model_maker.config import ExportFormat
from tflite_model_maker.text_classifier import AverageWordVecSpec
from tflite_model_maker.text_classifier import DataLoader

import tensorflow as tf
assert tf.__version__.startswith('2')
tf.get_logger().setLevel('ERROR')

spec = model_spec.get('mobilebert_classifier')

train_data = DataLoader.from_csv(
    filename='/path to file/train.csv',
    text_column='sentence',
    label_column='label',
    model_spec=spec,
    is_training=True)

model = text_classifier.create(train_data, model_spec=spec, epochs=10)

model.export(export_dir='average_word_vec')

代码看起来 运行 没问题,它为我创建了一个 model.tflite 文件。然后我用我的替换了演示 tflite 文件。但是当我 运行 演示时,出现以下错误:

 java.lang.AssertionError: Error occurred when initializing NLClassifier: Type mismatch for input tensor serving_default_input_type_ids:0. Requested STRING, got INT32.
        at org.tensorflow.lite.task.text.nlclassifier.NLClassifier.initJniWithByteBuffer(Native Method)
        at org.tensorflow.lite.task.text.nlclassifier.NLClassifier.access0(NLClassifier.java:67)
        at org.tensorflow.lite.task.text.nlclassifier.NLClassifier.createHandle(NLClassifier.java:223)
        at org.tensorflow.lite.task.core.TaskJniUtils.createHandleFromLibrary(TaskJniUtils.java:91)
        at org.tensorflow.lite.task.text.nlclassifier.NLClassifier.createFromBufferAndOptions(NLClassifier.java:219)
        at org.tensorflow.lite.task.text.nlclassifier.NLClassifier.createFromFileAndOptions(NLClassifier.java:175)
        at org.tensorflow.lite.task.text.nlclassifier.NLClassifier.createFromFile(NLClassifier.java:150)
        at org.tensorflow.lite.examples.textclassification.client.TextClassificationClient.load(TextClassificationClient.java:44)
        at org.tensorflow.lite.examples.textclassification.MainActivity.lambda$onStart$MainActivity(MainActivity.java:67)
        at org.tensorflow.lite.examples.textclassification.-$$Lambda$MainActivity$eJaQnJq74KcmPEczFE5swJIGydg.run(Unknown Source:2)

我错过了什么?

在您的代码中,您训练了一个 MobileBERT 模型,但保存到 average_word_vec 的路径? 规格 = model_spec.get('mobilebert_classifier') model.export(export_dir='average_word_vec')

一种可能是:你使用average_word_vec的模型,但是添加了MobileBERT元数据,因此预处理不匹配。

您能按照 Model Maker 教程再试一次吗? https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/tutorials/model_maker_text_classification.ipynb 确保更改导出路径。