初始化 NLClassifier 时出错:输入张量的类型不匹配 serving_default_input_type_ids:0。请求了 STRING,得到了 INT32
Error occurred when initializing NLClassifier: Type mismatch for input tensor serving_default_input_type_ids:0. Requested STRING, got INT32
我正在尝试学习如何将一些机器学习的东西用于 Android。我让 Text Classification demo 工作并且似乎工作正常。于是我尝试创建自己的模型。
我用来创建自己模型的代码是这样的:
import numpy as np
import os
from tflite_model_maker import model_spec
from tflite_model_maker import text_classifier
from tflite_model_maker.config import ExportFormat
from tflite_model_maker.text_classifier import AverageWordVecSpec
from tflite_model_maker.text_classifier import DataLoader
import tensorflow as tf
assert tf.__version__.startswith('2')
tf.get_logger().setLevel('ERROR')
spec = model_spec.get('mobilebert_classifier')
train_data = DataLoader.from_csv(
filename='/path to file/train.csv',
text_column='sentence',
label_column='label',
model_spec=spec,
is_training=True)
model = text_classifier.create(train_data, model_spec=spec, epochs=10)
model.export(export_dir='average_word_vec')
代码看起来 运行 没问题,它为我创建了一个 model.tflite
文件。然后我用我的替换了演示 tflite
文件。但是当我 运行 演示时,出现以下错误:
java.lang.AssertionError: Error occurred when initializing NLClassifier: Type mismatch for input tensor serving_default_input_type_ids:0. Requested STRING, got INT32.
at org.tensorflow.lite.task.text.nlclassifier.NLClassifier.initJniWithByteBuffer(Native Method)
at org.tensorflow.lite.task.text.nlclassifier.NLClassifier.access0(NLClassifier.java:67)
at org.tensorflow.lite.task.text.nlclassifier.NLClassifier.createHandle(NLClassifier.java:223)
at org.tensorflow.lite.task.core.TaskJniUtils.createHandleFromLibrary(TaskJniUtils.java:91)
at org.tensorflow.lite.task.text.nlclassifier.NLClassifier.createFromBufferAndOptions(NLClassifier.java:219)
at org.tensorflow.lite.task.text.nlclassifier.NLClassifier.createFromFileAndOptions(NLClassifier.java:175)
at org.tensorflow.lite.task.text.nlclassifier.NLClassifier.createFromFile(NLClassifier.java:150)
at org.tensorflow.lite.examples.textclassification.client.TextClassificationClient.load(TextClassificationClient.java:44)
at org.tensorflow.lite.examples.textclassification.MainActivity.lambda$onStart$MainActivity(MainActivity.java:67)
at org.tensorflow.lite.examples.textclassification.-$$Lambda$MainActivity$eJaQnJq74KcmPEczFE5swJIGydg.run(Unknown Source:2)
我错过了什么?
在您的代码中,您训练了一个 MobileBERT 模型,但保存到 average_word_vec 的路径?
规格 = model_spec.get('mobilebert_classifier')
model.export(export_dir='average_word_vec')
一种可能是:你使用average_word_vec的模型,但是添加了MobileBERT元数据,因此预处理不匹配。
您能按照 Model Maker 教程再试一次吗?
https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/tutorials/model_maker_text_classification.ipynb
确保更改导出路径。
我正在尝试学习如何将一些机器学习的东西用于 Android。我让 Text Classification demo 工作并且似乎工作正常。于是我尝试创建自己的模型。
我用来创建自己模型的代码是这样的:
import numpy as np
import os
from tflite_model_maker import model_spec
from tflite_model_maker import text_classifier
from tflite_model_maker.config import ExportFormat
from tflite_model_maker.text_classifier import AverageWordVecSpec
from tflite_model_maker.text_classifier import DataLoader
import tensorflow as tf
assert tf.__version__.startswith('2')
tf.get_logger().setLevel('ERROR')
spec = model_spec.get('mobilebert_classifier')
train_data = DataLoader.from_csv(
filename='/path to file/train.csv',
text_column='sentence',
label_column='label',
model_spec=spec,
is_training=True)
model = text_classifier.create(train_data, model_spec=spec, epochs=10)
model.export(export_dir='average_word_vec')
代码看起来 运行 没问题,它为我创建了一个 model.tflite
文件。然后我用我的替换了演示 tflite
文件。但是当我 运行 演示时,出现以下错误:
java.lang.AssertionError: Error occurred when initializing NLClassifier: Type mismatch for input tensor serving_default_input_type_ids:0. Requested STRING, got INT32.
at org.tensorflow.lite.task.text.nlclassifier.NLClassifier.initJniWithByteBuffer(Native Method)
at org.tensorflow.lite.task.text.nlclassifier.NLClassifier.access0(NLClassifier.java:67)
at org.tensorflow.lite.task.text.nlclassifier.NLClassifier.createHandle(NLClassifier.java:223)
at org.tensorflow.lite.task.core.TaskJniUtils.createHandleFromLibrary(TaskJniUtils.java:91)
at org.tensorflow.lite.task.text.nlclassifier.NLClassifier.createFromBufferAndOptions(NLClassifier.java:219)
at org.tensorflow.lite.task.text.nlclassifier.NLClassifier.createFromFileAndOptions(NLClassifier.java:175)
at org.tensorflow.lite.task.text.nlclassifier.NLClassifier.createFromFile(NLClassifier.java:150)
at org.tensorflow.lite.examples.textclassification.client.TextClassificationClient.load(TextClassificationClient.java:44)
at org.tensorflow.lite.examples.textclassification.MainActivity.lambda$onStart$MainActivity(MainActivity.java:67)
at org.tensorflow.lite.examples.textclassification.-$$Lambda$MainActivity$eJaQnJq74KcmPEczFE5swJIGydg.run(Unknown Source:2)
我错过了什么?
在您的代码中,您训练了一个 MobileBERT 模型,但保存到 average_word_vec 的路径? 规格 = model_spec.get('mobilebert_classifier') model.export(export_dir='average_word_vec')
一种可能是:你使用average_word_vec的模型,但是添加了MobileBERT元数据,因此预处理不匹配。
您能按照 Model Maker 教程再试一次吗? https://colab.sandbox.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/tutorials/model_maker_text_classification.ipynb 确保更改导出路径。