如何在 Java 中为 TensorFlow 创建 TensorProto?

How can I create TensorProto for TensorFlow in Java?

现在我们正在使用 tensorflow/serving 进行推理。它公开了 gRPC 服务,我们可以从 proto 文件中生成 Java 类。

现在我们可以从 https://github.com/tensorflow/serving/blob/master/tensorflow_serving/apis/prediction_service.proto 生成 PreditionService 但我如何从多维数组构造 TensorProto 对象。

我们有一些来自 Python ndarray 和 C++ 的示例。如果有人在 Java 中尝试过,那就太好了。

Java 中有一些关于 运行 TensorFlow 的工作。这是 blog 但我不确定它是否有效或我们如何在没有依赖项的情况下使用它。

TensorProto支持张量内容的两种表示:

  1. 各种repeated *_val字段(如TensorProto.float_valTensorProto.int_val),将内容存储为原始元素的线性数组,以行为主订单.

  2. TensorProto.tensor_content字段,将内容存储为单字节数组,对应tensorflow::Tensor::AsProtoTensorContent()的结果。 (一般情况下,这种表示对应于tensorflow::Tensor的内存中表示,转换为字节数组,但DT_STRING类型的处理方式不同。)

使用第一种格式生成 TensorProto 对象可能会更容易,但效率较低。假设您的 Java 程序中有一个名为 tensorData 的二维 float 数组,您可以使用以下代码作为起点:

float[][] tensorData = ...;
TensorProto.Builder builder = TensorProto.newBuilder();

// Set the shape and dtype fields.
// ...

// Set the float_val field.
for (int i = 0; i < tensorData.length; ++i) {
    for (int j = 0; j < tensorData[i].length; ++j) {
        builder.addFloatVal(tensorData[i][j]);
    }
}

TensorProto tensorProto = builder.build();