用于情感分析的 CNN 使用 Android 的 TFLearn 模型对用户输入进行分类
CNN for Sentiment Analysis using TFLearn model for Android to classify user input
我有一个用于文本分类的 CNN 模型,它使用手套的预训练嵌入。我已经冻结了为推理优化的图表,并在 android 工作室中使用它。问题是当我尝试将权重传递给模型进行推理时。我有一个 JSON 文件,其中包含单词和嵌入之间的键值对,我用它来根据用户输入的文本创建嵌入输入。我已经可以从 JSON 文件,但是当我尝试将其输入图表进行推理时,它给了我以下错误:
java.lang.IllegalArgumentException: indices[0,3891] = -2 is not in [0,
7459)
[[Node: EmbeddingLayer/embedding_lookup = Gather[Tindices=DT_INT32,
Tparams=DT_FLOAT, _class=["loc:@EmbeddingLayer/W"],
validate_indices=false,
_device="/job:localhost/replica:0/task:0/device:CPU:0"]
(EmbeddingLayer/W/read, EmbeddingLayer/Cast)]]
Android 代码在我的 GitHub
https://github.com/sushiboo/testNN1
给我带来问题的主要代码是分类方法:
private void classify(float[] input){
TFInference = new TensorFlowInferenceInterface(getAssets(), MODEL_FILE);
TFInference.feed(INPUT_NODE, input, 1, input.length);
TFInference.run(OUTPUT_NODES);
float[] resu = new float[2];
TFInference.fetch(OUTPUT_NODE, resu);
tvResult.setText("Programmer: " + Float.toString(resu[0]) + "\n Construction" + Float.toString(resu[1]));
Log.e("Result: ", Float.toString(resu[0]));
}
问题出在
TFInference.run(OUTPUT_NODES);
在错误消息中,数字“7459”表示嵌入层的输入维度。
我对这里发生的事情感到很困惑,但我知道索引[0,3891] = -2 在其中起着一定的作用。
问题出在模特身上。我已经修复了这个,现在我陷入了另一个错误。
我有一个用于文本分类的 CNN 模型,它使用手套的预训练嵌入。我已经冻结了为推理优化的图表,并在 android 工作室中使用它。问题是当我尝试将权重传递给模型进行推理时。我有一个 JSON 文件,其中包含单词和嵌入之间的键值对,我用它来根据用户输入的文本创建嵌入输入。我已经可以从 JSON 文件,但是当我尝试将其输入图表进行推理时,它给了我以下错误:
java.lang.IllegalArgumentException: indices[0,3891] = -2 is not in [0,
7459)
[[Node: EmbeddingLayer/embedding_lookup = Gather[Tindices=DT_INT32,
Tparams=DT_FLOAT, _class=["loc:@EmbeddingLayer/W"],
validate_indices=false,
_device="/job:localhost/replica:0/task:0/device:CPU:0"]
(EmbeddingLayer/W/read, EmbeddingLayer/Cast)]]
Android 代码在我的 GitHub https://github.com/sushiboo/testNN1
给我带来问题的主要代码是分类方法:
private void classify(float[] input){
TFInference = new TensorFlowInferenceInterface(getAssets(), MODEL_FILE);
TFInference.feed(INPUT_NODE, input, 1, input.length);
TFInference.run(OUTPUT_NODES);
float[] resu = new float[2];
TFInference.fetch(OUTPUT_NODE, resu);
tvResult.setText("Programmer: " + Float.toString(resu[0]) + "\n Construction" + Float.toString(resu[1]));
Log.e("Result: ", Float.toString(resu[0]));
}
问题出在
TFInference.run(OUTPUT_NODES);
在错误消息中,数字“7459”表示嵌入层的输入维度。
我对这里发生的事情感到很困惑,但我知道索引[0,3891] = -2 在其中起着一定的作用。
问题出在模特身上。我已经修复了这个,现在我陷入了另一个错误。