Python 中的 Tensorflow Java Api `toGraphDef` 是什么?
What is the Tensorflow Java Api `toGraphDef` equivalent in Python?
我正在使用 Tensorflow Java Api 将已创建的 Tensorflow 模型加载到 JVM 中。
我以此为例:tensorflow/examples/LabelImage.java
这是我的简单 Scala 代码:
import java.nio.file.{Files, Path, Paths}
import org.tensorflow.{Graph, Session, Tensor}
def readAllBytesOrExit(path: Path): Array[Byte] = Files.readAllBytes(path)
val graphDef = readAllBytesOrExit(Paths.get("PATH_TO_A_SINGLE_FILE_DESCRIBING_TF_MODEL.pb"))
val g = new Graph()
g.importGraphDef(graphDef)
val session = new Session(g)
val result: Tensor = session.runner().feed("input", image).fetch("output").run().get(0))
如何保存我的模型以将会话和图形存储在同一个文件中。如上文 "PATH_TO_A_SINGLE_FILE_DESCRIBING_TF_MODEL.pb" 所述。
描述 here 它提到:
The serialized representation of the graph, often referred to as a
GraphDef, can be generated by toGraphDef() and equivalents in other
language APIs.
其他语言 API 中的等效项是什么?我觉得不明显
注意:我已经查看了 tensorflow_serving 下的 mnist_saved_model.py,但通过该过程保存它会得到一个 .pb
文件和一个 variables
文件夹。尝试加载 .pb
文件时,我得到:java.lang.IllegalArgumentException: Invalid GraphDef
目前使用 tensorflow 的 Java API,我只找到了如何将图形保存为 graphDef(即没有其变量和元数据)。这可以通过将 Array[Byte] 写入文件来完成:
Files.write(Paths.get(modelDir, modelName), myGraph.toGraphDef)
此处 myGraph
是来自 Graph class 的 java 对象。
我建议使用 java SavedModel api defined here. It will save your model in a folder with both the serialized graph in a .pb file and the variables in a folder. Note the tag_constants you use as you'll need it in your scala/java code to load the model with the variables. Then the graph and session with variables are easily loaded with the SavedModelBundle java class Python API 保存您的模型 api。它 returns 你是一个包装器,包含图形和包含变量值的会话:
val model = SavedModelBundle.load(modelDir, modelTag)
如果您已经尝试过,也许您可以分享您的代码以了解它返回无效 GraphDef 的原因。
另一种选择是冻结你的图表,即将你的变量节点变成常量节点,这样一切都在 .pb 文件中是独立的。更多信息 here 冷冻部分
我正在使用 Tensorflow Java Api 将已创建的 Tensorflow 模型加载到 JVM 中。 我以此为例:tensorflow/examples/LabelImage.java
这是我的简单 Scala 代码:
import java.nio.file.{Files, Path, Paths}
import org.tensorflow.{Graph, Session, Tensor}
def readAllBytesOrExit(path: Path): Array[Byte] = Files.readAllBytes(path)
val graphDef = readAllBytesOrExit(Paths.get("PATH_TO_A_SINGLE_FILE_DESCRIBING_TF_MODEL.pb"))
val g = new Graph()
g.importGraphDef(graphDef)
val session = new Session(g)
val result: Tensor = session.runner().feed("input", image).fetch("output").run().get(0))
如何保存我的模型以将会话和图形存储在同一个文件中。如上文 "PATH_TO_A_SINGLE_FILE_DESCRIBING_TF_MODEL.pb" 所述。
描述 here 它提到:
The serialized representation of the graph, often referred to as a GraphDef, can be generated by toGraphDef() and equivalents in other language APIs.
其他语言 API 中的等效项是什么?我觉得不明显
注意:我已经查看了 tensorflow_serving 下的 mnist_saved_model.py,但通过该过程保存它会得到一个 .pb
文件和一个 variables
文件夹。尝试加载 .pb
文件时,我得到:java.lang.IllegalArgumentException: Invalid GraphDef
目前使用 tensorflow 的 Java API,我只找到了如何将图形保存为 graphDef(即没有其变量和元数据)。这可以通过将 Array[Byte] 写入文件来完成:
Files.write(Paths.get(modelDir, modelName), myGraph.toGraphDef)
此处 myGraph
是来自 Graph class 的 java 对象。
我建议使用 java SavedModel api defined here. It will save your model in a folder with both the serialized graph in a .pb file and the variables in a folder. Note the tag_constants you use as you'll need it in your scala/java code to load the model with the variables. Then the graph and session with variables are easily loaded with the SavedModelBundle java class Python API 保存您的模型 api。它 returns 你是一个包装器,包含图形和包含变量值的会话:
val model = SavedModelBundle.load(modelDir, modelTag)
如果您已经尝试过,也许您可以分享您的代码以了解它返回无效 GraphDef 的原因。
另一种选择是冻结你的图表,即将你的变量节点变成常量节点,这样一切都在 .pb 文件中是独立的。更多信息 here 冷冻部分