Tensorflow 模型导入 Java

Tensorflow model import to Java

我一直在尝试在 Java.

中导入和使用我的训练模型(Tensorflow,Python)

我能够在 Python 中保存模型,但是当我尝试在 Java 中使用相同模型进行预测时遇到问题。

Here,可以看到初始化、训练、保存模型的python代码

Here,您可以看到 Java 用于导入和预测输入值的代码。

我收到的错误信息是:

Exception in thread "main" java.lang.IllegalStateException: Attempting to use uninitialized value Variable_7
     [[Node: Variable_7/read = Identity[T=DT_FLOAT, _class=["loc:@Variable_7"], _device="/job:localhost/replica:0/task:0/cpu:0"](Variable_7)]]
    at org.tensorflow.Session.run(Native Method)
    at org.tensorflow.Session.access0(Session.java:48)
    at org.tensorflow.Session$Runner.runHelper(Session.java:285)
    at org.tensorflow.Session$Runner.run(Session.java:235)
    at org.tensorflow.examples.Identity_import.main(Identity_import.java:35)

我相信,问题出在 python 代码中的某处,但我找不到它。

您的 python-模型在这方面肯定会失败:

sess.run(init) #<---this will fail
save_model(sess)
error = tf.reduce_mean(tf.square(prediction - y))

#accuracy = tf.reduce_mean(tf.cast(error, 'float'))
print('Error:', error)

init 没有在模型中定义 - 我不确定你想在这个地方实现什么,但这应该给你一个起点

Java importGraphDef() 函数仅导入计算图(由 tf.train.write_graph 在您的 Python 代码中编写),它不会加载经过训练的值变量(存储在检查点中),这就是为什么您会收到抱怨未初始化变量的错误。

TensorFlow SavedModel format on the other hand includes all information about a model (graph, checkpoint state, other metadata) and to use in Java you'd want to use SavedModelBundle.load 创建使用训练变量值初始化的会话。

要从 Python 导出此格式的模型,您可能需要查看相关问题

在您的情况下,这应该类似于 Python 中的以下内容:

def save_model(session, input_tensor, output_tensor):
  signature = tf.saved_model.signature_def_utils.build_signature_def(
    inputs = {'input': tf.saved_model.utils.build_tensor_info(input_tensor)},
    outputs = {'output': tf.saved_model.utils.build_tensor_info(output_tensor)},
  )
  b = saved_model_builder.SavedModelBuilder('/tmp/model')
  b.add_meta_graph_and_variables(session,
                                 [tf.saved_model.tag_constants.SERVING],
                                 signature_def_map={tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature})
  b.save() 

并通过 save_model(session, x, yhat)

调用它

然后在 Java 中使用以下方式加载模型:

try (SavedModelBundle b = SavedModelBundle.load("/tmp/mymodel", "serve")) {
  // b.session().run(...)
}

希望对您有所帮助。

顺便说一句,Deeplearning4j 允许您导入使用 Keras 1.0 在 TensorFlow 上训练的模型(Keras 2.0 支持即将推出)。

https://deeplearning4j.org/model-import-keras

我们还构建了一个名为 Jumpy 的库,它是 Numpy 数组和 Pyjnius 的包装器,它使用指针而不是复制数据,这使得它在处理张量时比 Py4j 更高效。

https://deeplearning4j.org/jumpy