Java Tensorflow + Keras 等价于 model.predict()

Java Tensorflow + Keras Equivalent of model.predict()

在 python 中,您只需将一个 numpy 数组传递给 predict() 即可从您的模型中获取预测结果。使用 Java 和 SavedModelBundle 的等价物是什么?

Python

model = tf.keras.models.Sequential([
  # layers go here
])
model.compile(...)
model.fit(x_train, y_train)

predictions = model.predict(x_test_maxabs) # <= This line 

Java

SavedModelBundle model = SavedModelBundle.load(path, "serve");
model.predict() // ????? // What does it take as in input? Tensor?

TensorFlow Python 自动将您的 NumPy 数组转换为 tf.Tensor。在 TensorFlow Java 中,您可以直接操作张量。

现在 SavedModelBundle 没有 predict 方法。您需要获取会话并 运行 它,使用 SessionRunner 并为其提供输入张量。

例如,基于下一代 TF Java (https://github.com/tensorflow/java),您的代码最终看起来像这样(请注意,我在这里对 x_test_maxabs 因为你的代码示例没有解释清楚它来自哪里):

try (SavedModelBundle model = SavedModelBundle.load(path, "serve")) {
    try (Tensor<TFloat32> input = TFloat32.tensorOf(...);
        Tensor<TFloat32> output = model.session()
            .runner()
            .feed("input_name", input)
            .fetch("output_name")
            .run()
            .expect(TFloat32.class)) {

        float prediction = output.data().getFloat();
        System.out.println("prediction = " + prediction);
    }        
}

如果您不确定图中的 input/output 张量的名称是什么,您可以通过查看签名定义以编程方式获取:

model.metaGraphDef().getSignatureDefMap().get("serving_default")

你可以试试 Deep Java Library (DJL).

DJL 内部使用 Tensorflow java 并提供高级 API 以简化推理:

Criteria<Image, Classifications> criteria =
    Criteria.builder()
        .setTypes(Image.class, Classifications.class)
        .optModelUrls("https://example.com/squeezenet.zip")
        .optTranslator(ImageClassificationTranslator
               .builder().addTransform(new ToTensor()).build())
        .build();

try (ZooModel<Image, Classification> model = ModelZoo.load(criteria);
        Predictor<Image, Classification> predictor = model.newPredictor()) {
    Image image = ImageFactory.getInstance().fromUrl("https://myimage.jpg");
    Classification result = predictor.predict(image);
}


查看 github 存储库:https://github.com/awslabs/djl

有一篇博文:https://towardsdatascience.com/detecting-pneumonia-from-chest-x-ray-images-e02bcf705dd6

并且可以找到演示项目:https://github.com/aws-samples/djl-demo/blob/master/pneumonia-detection/README.md

0.3.1 API:

val model: SavedModelBundle = SavedModelBundle.load("path/to/model", "serve")

val inputTensor = TFloat32.tesnorOf(..)

val function: ConcreteFunction = model.function(Signature.DEFAULT_KEY)
val result: Tensor = function.call(inputTensor) // u can cast to type you expect, a type of returning tensor can be checked by signature: model.function("serving_default").signature().toString()

在你得到任何子类型的结果张量后,你可以迭代它的值。在我的示例中,我有一个形状为 (1, 56)TFloat32,因此我通过 result.get(0, idx)

找到了最大值