Java Tensorflow + Keras 等价于 model.predict()
Java Tensorflow + Keras Equivalent of model.predict()
在 python 中,您只需将一个 numpy 数组传递给 predict()
即可从您的模型中获取预测结果。使用 Java 和 SavedModelBundle
的等价物是什么?
Python
model = tf.keras.models.Sequential([
# layers go here
])
model.compile(...)
model.fit(x_train, y_train)
predictions = model.predict(x_test_maxabs) # <= This line
Java
SavedModelBundle model = SavedModelBundle.load(path, "serve");
model.predict() // ????? // What does it take as in input? Tensor?
TensorFlow Python 自动将您的 NumPy 数组转换为 tf.Tensor
。在 TensorFlow Java 中,您可以直接操作张量。
现在 SavedModelBundle
没有 predict
方法。您需要获取会话并 运行 它,使用 SessionRunner
并为其提供输入张量。
例如,基于下一代 TF Java (https://github.com/tensorflow/java),您的代码最终看起来像这样(请注意,我在这里对 x_test_maxabs
因为你的代码示例没有解释清楚它来自哪里):
try (SavedModelBundle model = SavedModelBundle.load(path, "serve")) {
try (Tensor<TFloat32> input = TFloat32.tensorOf(...);
Tensor<TFloat32> output = model.session()
.runner()
.feed("input_name", input)
.fetch("output_name")
.run()
.expect(TFloat32.class)) {
float prediction = output.data().getFloat();
System.out.println("prediction = " + prediction);
}
}
如果您不确定图中的 input/output 张量的名称是什么,您可以通过查看签名定义以编程方式获取:
model.metaGraphDef().getSignatureDefMap().get("serving_default")
你可以试试 Deep Java Library (DJL).
DJL 内部使用 Tensorflow java 并提供高级 API 以简化推理:
Criteria<Image, Classifications> criteria =
Criteria.builder()
.setTypes(Image.class, Classifications.class)
.optModelUrls("https://example.com/squeezenet.zip")
.optTranslator(ImageClassificationTranslator
.builder().addTransform(new ToTensor()).build())
.build();
try (ZooModel<Image, Classification> model = ModelZoo.load(criteria);
Predictor<Image, Classification> predictor = model.newPredictor()) {
Image image = ImageFactory.getInstance().fromUrl("https://myimage.jpg");
Classification result = predictor.predict(image);
}
查看 github 存储库:https://github.com/awslabs/djl
有一篇博文:https://towardsdatascience.com/detecting-pneumonia-from-chest-x-ray-images-e02bcf705dd6
并且可以找到演示项目:https://github.com/aws-samples/djl-demo/blob/master/pneumonia-detection/README.md
在0.3.1
API:
val model: SavedModelBundle = SavedModelBundle.load("path/to/model", "serve")
val inputTensor = TFloat32.tesnorOf(..)
val function: ConcreteFunction = model.function(Signature.DEFAULT_KEY)
val result: Tensor = function.call(inputTensor) // u can cast to type you expect, a type of returning tensor can be checked by signature: model.function("serving_default").signature().toString()
在你得到任何子类型的结果张量后,你可以迭代它的值。在我的示例中,我有一个形状为 (1, 56)
的 TFloat32
,因此我通过 result.get(0, idx)
找到了最大值
在 python 中,您只需将一个 numpy 数组传递给 predict()
即可从您的模型中获取预测结果。使用 Java 和 SavedModelBundle
的等价物是什么?
Python
model = tf.keras.models.Sequential([
# layers go here
])
model.compile(...)
model.fit(x_train, y_train)
predictions = model.predict(x_test_maxabs) # <= This line
Java
SavedModelBundle model = SavedModelBundle.load(path, "serve");
model.predict() // ????? // What does it take as in input? Tensor?
TensorFlow Python 自动将您的 NumPy 数组转换为 tf.Tensor
。在 TensorFlow Java 中,您可以直接操作张量。
现在 SavedModelBundle
没有 predict
方法。您需要获取会话并 运行 它,使用 SessionRunner
并为其提供输入张量。
例如,基于下一代 TF Java (https://github.com/tensorflow/java),您的代码最终看起来像这样(请注意,我在这里对 x_test_maxabs
因为你的代码示例没有解释清楚它来自哪里):
try (SavedModelBundle model = SavedModelBundle.load(path, "serve")) {
try (Tensor<TFloat32> input = TFloat32.tensorOf(...);
Tensor<TFloat32> output = model.session()
.runner()
.feed("input_name", input)
.fetch("output_name")
.run()
.expect(TFloat32.class)) {
float prediction = output.data().getFloat();
System.out.println("prediction = " + prediction);
}
}
如果您不确定图中的 input/output 张量的名称是什么,您可以通过查看签名定义以编程方式获取:
model.metaGraphDef().getSignatureDefMap().get("serving_default")
你可以试试 Deep Java Library (DJL).
DJL 内部使用 Tensorflow java 并提供高级 API 以简化推理:
Criteria<Image, Classifications> criteria =
Criteria.builder()
.setTypes(Image.class, Classifications.class)
.optModelUrls("https://example.com/squeezenet.zip")
.optTranslator(ImageClassificationTranslator
.builder().addTransform(new ToTensor()).build())
.build();
try (ZooModel<Image, Classification> model = ModelZoo.load(criteria);
Predictor<Image, Classification> predictor = model.newPredictor()) {
Image image = ImageFactory.getInstance().fromUrl("https://myimage.jpg");
Classification result = predictor.predict(image);
}
查看 github 存储库:https://github.com/awslabs/djl
有一篇博文:https://towardsdatascience.com/detecting-pneumonia-from-chest-x-ray-images-e02bcf705dd6
并且可以找到演示项目:https://github.com/aws-samples/djl-demo/blob/master/pneumonia-detection/README.md
在0.3.1
API:
val model: SavedModelBundle = SavedModelBundle.load("path/to/model", "serve")
val inputTensor = TFloat32.tesnorOf(..)
val function: ConcreteFunction = model.function(Signature.DEFAULT_KEY)
val result: Tensor = function.call(inputTensor) // u can cast to type you expect, a type of returning tensor can be checked by signature: model.function("serving_default").signature().toString()
在你得到任何子类型的结果张量后,你可以迭代它的值。在我的示例中,我有一个形状为 (1, 56)
的 TFloat32
,因此我通过 result.get(0, idx)