使用张量流 "saved model" api 对 java 与 python 中加载的模型进行错误预测
Getting wrong predictions on a model loaded in java vs python using the tensorflow "saved model" api
我正在尝试在 Java 中加载一个模型,该模型在 python 中训练并使用保存的模型 api (from tensorflow.python.saved_model
).
我可以在单独的 Python 脚本和 Java 中加载它,但是 Java 版本中的预测是错误的。
我用一个简单的模型编写了一个快速示例项目来演示 "bug"(我希望我的误解)。
Python: OrTraining.py
使用保存的模型训练后保存模型 Api。
builders = saved_model_builder.SavedModelBuilder(export_path)
builders.add_meta_graph_and_variables(sess, ["or"], signature_def_map={
"predict": tf.saved_model.signature_def_utils.predict_signature_def(
inputs= {"images": x_placeholder},
outputs= {"scores": hypothesis_function})
})
builders.save()
https://github.com/JsFlo/DebuggingSavedModelJava/blob/master/OrTraining.py
Python: OrLoadSavedModel.py
使用保存的模型在单独的脚本中加载模型 Api。
with tf.Session(graph=tf.Graph()) as sess:
tf.saved_model.loader.load(sess, ["or"], "orTrainingModels")
graph = tf.get_default_graph()
print(graph.get_operations())
x_placeholder = graph.get_tensor_by_name("or_inputs:0")
hypothesis_function = graph.get_tensor_by_name("hypothesis_output:0")
# sess.run("init")
print(sess.run(hypothesis_function, feed_dict={x_placeholder: np.array([
np.array([1, 0]),
np.array([0, 1]),
np.array([0, 0]),
np.array([1, 1]),
])}))
https://github.com/JsFlo/DebuggingSavedModelJava/blob/master/OrLoadSavedModel.py
Java: OrLoadSavedModel.java
加载
SavedModelBundle savedModelBundle = SavedModelBundle.load("./orTrainingModels", "or");
Session session = savedModelBundle.session();
运行
Tensor result = session.runner()
.feed("or_inputs", tensorInput)
.fetch("hypothesis_output")
.run().get(0);
https://github.com/JsFlo/DebuggingSavedModelJava/blob/master/src/main/java/OrLoadSavedModel.java
java 版本和 python 版本加载和 运行 图表都没有问题,但 java 版本没有输出正确的预测。
起初我以为是因为 weights/bias 没有被加载,但我能够 "run" weights/bias 在 java 版本中的操作和看到它具有我在训练后在 python 脚本中看到的正确权重。
检查权重 java (https://github.com/JsFlo/DebuggingSavedModelJava)
Tensor result = session.runner()
.fetch("da_weights")
.run().get(0);
事实证明这是我输入数据的方式的问题。Tensorflow 不喜欢创建 Boxed Types
的张量(整数与整数/浮点数与浮点数)并且有检查查看您是否尝试传递盒装类型,但似乎检查不那么全面。
@Test
public void testCreateFromArrayOfBoxed() {
Integer[] vector = new Integer[] {1, 2, 3, 4};
try (Tensor<Integer> t = Tensor.create(vector, Integer.class)) {
fail("Tensor.create() should fail because it was given an array of boxed values");
} catch (IllegalArgumentException e) {
// The expected exception
}
}
这是我的问题的一个例子:
Float[] input = new Float[]{0f, 1f};
Tensor tensorOutput = Tensor.create(input);
float[] floatOutput= new float[2];
tensorOutput.copyTo(floatOutput);
println(Arrays.toString(floatOutput)); // -7.377E30, -7.377E30
float[] input = new float[]{0f, 1f};
Tensor tensorOutput = Tensor.create(input);
float[] floatOutput= new float[2];
tensorOutput.copyTo(floatOutput);
println(Arrays.toString(floatOutput)); // 0, 1
我正在尝试在 Java 中加载一个模型,该模型在 python 中训练并使用保存的模型 api (from tensorflow.python.saved_model
).
我可以在单独的 Python 脚本和 Java 中加载它,但是 Java 版本中的预测是错误的。
我用一个简单的模型编写了一个快速示例项目来演示 "bug"(我希望我的误解)。
Python: OrTraining.py
使用保存的模型训练后保存模型 Api。
builders = saved_model_builder.SavedModelBuilder(export_path)
builders.add_meta_graph_and_variables(sess, ["or"], signature_def_map={
"predict": tf.saved_model.signature_def_utils.predict_signature_def(
inputs= {"images": x_placeholder},
outputs= {"scores": hypothesis_function})
})
builders.save()
https://github.com/JsFlo/DebuggingSavedModelJava/blob/master/OrTraining.py
Python: OrLoadSavedModel.py
使用保存的模型在单独的脚本中加载模型 Api。
with tf.Session(graph=tf.Graph()) as sess:
tf.saved_model.loader.load(sess, ["or"], "orTrainingModels")
graph = tf.get_default_graph()
print(graph.get_operations())
x_placeholder = graph.get_tensor_by_name("or_inputs:0")
hypothesis_function = graph.get_tensor_by_name("hypothesis_output:0")
# sess.run("init")
print(sess.run(hypothesis_function, feed_dict={x_placeholder: np.array([
np.array([1, 0]),
np.array([0, 1]),
np.array([0, 0]),
np.array([1, 1]),
])}))
https://github.com/JsFlo/DebuggingSavedModelJava/blob/master/OrLoadSavedModel.py
Java: OrLoadSavedModel.java
加载
SavedModelBundle savedModelBundle = SavedModelBundle.load("./orTrainingModels", "or");
Session session = savedModelBundle.session();
运行
Tensor result = session.runner()
.feed("or_inputs", tensorInput)
.fetch("hypothesis_output")
.run().get(0);
https://github.com/JsFlo/DebuggingSavedModelJava/blob/master/src/main/java/OrLoadSavedModel.java
java 版本和 python 版本加载和 运行 图表都没有问题,但 java 版本没有输出正确的预测。
起初我以为是因为 weights/bias 没有被加载,但我能够 "run" weights/bias 在 java 版本中的操作和看到它具有我在训练后在 python 脚本中看到的正确权重。
检查权重 java (https://github.com/JsFlo/DebuggingSavedModelJava)
Tensor result = session.runner()
.fetch("da_weights")
.run().get(0);
事实证明这是我输入数据的方式的问题。Tensorflow 不喜欢创建 Boxed Types
的张量(整数与整数/浮点数与浮点数)并且有检查查看您是否尝试传递盒装类型,但似乎检查不那么全面。
@Test
public void testCreateFromArrayOfBoxed() {
Integer[] vector = new Integer[] {1, 2, 3, 4};
try (Tensor<Integer> t = Tensor.create(vector, Integer.class)) {
fail("Tensor.create() should fail because it was given an array of boxed values");
} catch (IllegalArgumentException e) {
// The expected exception
}
}
这是我的问题的一个例子:
Float[] input = new Float[]{0f, 1f};
Tensor tensorOutput = Tensor.create(input);
float[] floatOutput= new float[2];
tensorOutput.copyTo(floatOutput);
println(Arrays.toString(floatOutput)); // -7.377E30, -7.377E30
float[] input = new float[]{0f, 1f};
Tensor tensorOutput = Tensor.create(input);
float[] floatOutput= new float[2];
tensorOutput.copyTo(floatOutput);
println(Arrays.toString(floatOutput)); // 0, 1