Tensorflow 模型导入 Java
Tensorflow model import to Java
我一直在尝试在 Java.
中导入和使用我的训练模型(Tensorflow,Python)
我能够在 Python 中保存模型,但是当我尝试在 Java 中使用相同模型进行预测时遇到问题。
Here,可以看到初始化、训练、保存模型的python代码
Here,您可以看到 Java 用于导入和预测输入值的代码。
我收到的错误信息是:
Exception in thread "main" java.lang.IllegalStateException: Attempting to use uninitialized value Variable_7
[[Node: Variable_7/read = Identity[T=DT_FLOAT, _class=["loc:@Variable_7"], _device="/job:localhost/replica:0/task:0/cpu:0"](Variable_7)]]
at org.tensorflow.Session.run(Native Method)
at org.tensorflow.Session.access0(Session.java:48)
at org.tensorflow.Session$Runner.runHelper(Session.java:285)
at org.tensorflow.Session$Runner.run(Session.java:235)
at org.tensorflow.examples.Identity_import.main(Identity_import.java:35)
我相信,问题出在 python 代码中的某处,但我找不到它。
您的 python-模型在这方面肯定会失败:
sess.run(init) #<---this will fail
save_model(sess)
error = tf.reduce_mean(tf.square(prediction - y))
#accuracy = tf.reduce_mean(tf.cast(error, 'float'))
print('Error:', error)
init
没有在模型中定义 - 我不确定你想在这个地方实现什么,但这应该给你一个起点
Java importGraphDef()
函数仅导入计算图(由 tf.train.write_graph
在您的 Python 代码中编写),它不会加载经过训练的值变量(存储在检查点中),这就是为什么您会收到抱怨未初始化变量的错误。
TensorFlow SavedModel format on the other hand includes all information about a model (graph, checkpoint state, other metadata) and to use in Java you'd want to use SavedModelBundle.load
创建使用训练变量值初始化的会话。
要从 Python 导出此格式的模型,您可能需要查看相关问题
在您的情况下,这应该类似于 Python 中的以下内容:
def save_model(session, input_tensor, output_tensor):
signature = tf.saved_model.signature_def_utils.build_signature_def(
inputs = {'input': tf.saved_model.utils.build_tensor_info(input_tensor)},
outputs = {'output': tf.saved_model.utils.build_tensor_info(output_tensor)},
)
b = saved_model_builder.SavedModelBuilder('/tmp/model')
b.add_meta_graph_and_variables(session,
[tf.saved_model.tag_constants.SERVING],
signature_def_map={tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature})
b.save()
并通过 save_model(session, x, yhat)
调用它
然后在 Java 中使用以下方式加载模型:
try (SavedModelBundle b = SavedModelBundle.load("/tmp/mymodel", "serve")) {
// b.session().run(...)
}
希望对您有所帮助。
顺便说一句,Deeplearning4j 允许您导入使用 Keras 1.0 在 TensorFlow 上训练的模型(Keras 2.0 支持即将推出)。
https://deeplearning4j.org/model-import-keras
我们还构建了一个名为 Jumpy 的库,它是 Numpy 数组和 Pyjnius 的包装器,它使用指针而不是复制数据,这使得它在处理张量时比 Py4j 更高效。
我一直在尝试在 Java.
中导入和使用我的训练模型(Tensorflow,Python)我能够在 Python 中保存模型,但是当我尝试在 Java 中使用相同模型进行预测时遇到问题。
Here,可以看到初始化、训练、保存模型的python代码
Here,您可以看到 Java 用于导入和预测输入值的代码。
我收到的错误信息是:
Exception in thread "main" java.lang.IllegalStateException: Attempting to use uninitialized value Variable_7
[[Node: Variable_7/read = Identity[T=DT_FLOAT, _class=["loc:@Variable_7"], _device="/job:localhost/replica:0/task:0/cpu:0"](Variable_7)]]
at org.tensorflow.Session.run(Native Method)
at org.tensorflow.Session.access0(Session.java:48)
at org.tensorflow.Session$Runner.runHelper(Session.java:285)
at org.tensorflow.Session$Runner.run(Session.java:235)
at org.tensorflow.examples.Identity_import.main(Identity_import.java:35)
我相信,问题出在 python 代码中的某处,但我找不到它。
您的 python-模型在这方面肯定会失败:
sess.run(init) #<---this will fail
save_model(sess)
error = tf.reduce_mean(tf.square(prediction - y))
#accuracy = tf.reduce_mean(tf.cast(error, 'float'))
print('Error:', error)
init
没有在模型中定义 - 我不确定你想在这个地方实现什么,但这应该给你一个起点
Java importGraphDef()
函数仅导入计算图(由 tf.train.write_graph
在您的 Python 代码中编写),它不会加载经过训练的值变量(存储在检查点中),这就是为什么您会收到抱怨未初始化变量的错误。
TensorFlow SavedModel format on the other hand includes all information about a model (graph, checkpoint state, other metadata) and to use in Java you'd want to use SavedModelBundle.load
创建使用训练变量值初始化的会话。
要从 Python 导出此格式的模型,您可能需要查看相关问题
在您的情况下,这应该类似于 Python 中的以下内容:
def save_model(session, input_tensor, output_tensor):
signature = tf.saved_model.signature_def_utils.build_signature_def(
inputs = {'input': tf.saved_model.utils.build_tensor_info(input_tensor)},
outputs = {'output': tf.saved_model.utils.build_tensor_info(output_tensor)},
)
b = saved_model_builder.SavedModelBuilder('/tmp/model')
b.add_meta_graph_and_variables(session,
[tf.saved_model.tag_constants.SERVING],
signature_def_map={tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature})
b.save()
并通过 save_model(session, x, yhat)
然后在 Java 中使用以下方式加载模型:
try (SavedModelBundle b = SavedModelBundle.load("/tmp/mymodel", "serve")) {
// b.session().run(...)
}
希望对您有所帮助。
顺便说一句,Deeplearning4j 允许您导入使用 Keras 1.0 在 TensorFlow 上训练的模型(Keras 2.0 支持即将推出)。
https://deeplearning4j.org/model-import-keras
我们还构建了一个名为 Jumpy 的库,它是 Numpy 数组和 Pyjnius 的包装器,它使用指针而不是复制数据,这使得它在处理张量时比 Py4j 更高效。