Java gRPC client predict call to half_plus_two example model

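I am trying to call TensorFlow Serving from a Java client. The model being served is the half_plus_two example model. I can make REST calls successfully, but I can't get the equivalent gRPC call to work.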

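I tried passing a string as the model input, and also passing an array of floats to the tensor proto builder. When I print it out, the tensor proto seems to contain the correct data: [1.0, 2.0, 5.0]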

        String host = "localhost";
        int port = 8500;
        // the model's name.
        String modelName = "half_plus_two";
        // model's version
        long modelVersion = 123;
        // assume this model takes input of free text, and make some sentiment prediction.
//        String modelInput = "some text input to make prediction with";
        String modelInput = "{\"instances\": [1.0, 2.0, 5.0]}";

        // create a channel
        ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build();
        tensorflow.serving.PredictionServiceGrpc.PredictionServiceBlockingStub stub = tensorflow.serving.PredictionServiceGrpc.newBlockingStub(channel);

        // create a modelspec
        tensorflow.serving.Model.ModelSpec.Builder modelSpecBuilder = tensorflow.serving.Model.ModelSpec.newBuilder();
        modelSpecBuilder.setName(modelName);
        modelSpecBuilder.setVersion(Int64Value.of(modelVersion));
        modelSpecBuilder.setSignatureName("serving_default");

        Predict.PredictRequest.Builder builder = Predict.PredictRequest.newBuilder();
        builder.setModelSpec(modelSpecBuilder);

        // create the TensorProto and request

        float[] floatData = new float[3];
        floatData[0] = 1.0f;
        floatData[1] = 2.0f;
        floatData[2] = 5.0f;


        org.tensorflow.framework.TensorProto.Builder tensorProtoBuilder = org.tensorflow.framework.TensorProto.newBuilder();
        tensorProtoBuilder.setDtype(DataType.DT_FLOAT);
        org.tensorflow.framework.TensorShapeProto.Builder tensorShapeBuilder = org.tensorflow.framework.TensorShapeProto.newBuilder();
        tensorShapeBuilder.addDim(org.tensorflow.framework.TensorShapeProto.Dim.newBuilder().setSize(3));
        tensorProtoBuilder.setTensorShape(tensorShapeBuilder.build());

        // Set the float_val field.
        for (int i = 0; i < floatData.length; i++) {
            tensorProtoBuilder.addFloatVal(floatData[i]);
        }


        org.tensorflow.framework.TensorProto tp = tensorProtoBuilder.build();

        System.out.println(tp.getFloatValList());

        builder.putInputs("inputs", tp);

        Predict.PredictRequest request = builder.build();
        Predict.PredictResponse response = stub.predict(request);

When I print the request, it looks like this:

model_spec {
  name: "half_plus_two"
  version {
    value: 123
  }
  signature_name: "serving_default"
}
inputs {
  key: "inputs"
  value {
    dtype: DT_FLOAT
    tensor_shape {
      dim {
        size: -1
      }
      dim {
        size: 1
      }
    }
    float_val: 1.0
    float_val: 2.0
    float_val: 5.0
  }
}

And I get this exception:

Exception in thread "main" io.grpc.StatusRuntimeException: INVALID_ARGUMENT: input tensor alias not found in signature: inputs. Inputs expected to be in the set {x}.
    at io.grpc.stub.ClientCalls.toStatusRuntimeException(ClientCalls.java:233)
    at io.grpc.stub.ClientCalls.getUnchecked(ClientCalls.java:214)
    at io.grpc.stub.ClientCalls.blockingUnaryCall(ClientCalls.java:139)
    at tensorflow.serving.PredictionServiceGrpc$PredictionServiceBlockingStub.predict(PredictionServiceGrpc.java:446)
    at com.avaya.ccml.grpc.GrpcClient.main(GrpcClient.java:72)

Edit: still working on it.

It looks like the tensor proto I am supplying is incorrect.

I checked with saved_model_cli, and it shows the correct shape:

The given SavedModel SignatureDef contains the following input(s):
  inputs['x'] tensor_info:
      dtype: DT_FLOAT
      shape: (-1, 1)
      name: x:0
The given SavedModel SignatureDef contains the following output(s):
  outputs['y'] tensor_info:
      dtype: DT_FLOAT
      shape: (-1, 1)
      name: y:0
Method name is: tensorflow/serving/predict

So next I need to figure out how to build a tensor proto with this structure.
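A minimal sketch of that layout, in plain Java with no TensorFlow classes so it runs standalone. In the signature, the -1 means "any batch size"; the tensor actually sent should use concrete sizes, so three scalar inputs become shape [3, 1], with the values stored row-major in float_val. ColumnTensorSketch, shapeOf and flattenRowMajor are hypothetical helper names for illustration.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Sketch only: computes the two pieces a rank-2 float TensorProto needs,
// the dim sizes and the flat row-major value list for float_val.
// No TensorFlow or gRPC classes are used here.
final class ColumnTensorSketch {

    // Dim sizes of a rectangular float[rows][cols] array, e.g. {3, 1}.
    static long[] shapeOf(float[][] matrix) {
        int rows = matrix.length;
        int cols = rows == 0 ? 0 : matrix[0].length;
        for (float[] row : matrix) {
            if (row.length != cols) {
                throw new IllegalArgumentException("ragged rows are not a valid tensor");
            }
        }
        return new long[] {rows, cols};
    }

    // Row-major flattening: all of row 0, then all of row 1, and so on.
    static List<Float> flattenRowMajor(float[][] matrix) {
        List<Float> flat = new ArrayList<>();
        for (float[] row : matrix) {
            for (float v : row) {
                flat.add(v);
            }
        }
        return flat;
    }

    public static void main(String[] args) {
        // Three scalar inputs as a batch of shape (3, 1), matching (-1, 1).
        float[][] batch = {{1.0f}, {2.0f}, {5.0f}};
        System.out.println(Arrays.toString(shapeOf(batch))); // [3, 1]
        System.out.println(flattenRowMajor(batch));          // [1.0, 2.0, 5.0]
    }
}
```

In the gRPC client, each entry returned by shapeOf would go into tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(d)) and each flattened value into tensorProtoBuilder.addFloatVal(v), the same builder calls the question already uses. The working code at the end of this thread keeps a rank-1 shape of [3], which the author reports works (presumably because the model is element-wise); [3, 1] is the layout that mirrors the signature exactly.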


In the half_plus_two example, they use the instances key for the input values: https://www.tensorflow.org/tfx/serving/docker#serving_example

Could you try setting it to instances, like this?

    builder.putInputs("instances", tp);

I also think there may be a problem with the DType. You should use DT_FLOAT rather than DT_STRING, since the saved_model_cli output shows dtype: DT_FLOAT:

    tensorProtoBuilder.setDtype(DataType.DT_FLOAT);
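For comparison, in the docker example linked above, "instances" is the top-level key of the REST JSON body (presumably similar to the REST call the question says already works), whereas over gRPC the keys of the PredictRequest inputs map are the signature's input aliases. A sketch that builds, without sending, that REST request using java.net.http (Java 11+), assuming the REST API listens on localhost:8501 as in the docker example; RestPredictSketch and its methods are hypothetical names.

```java
import java.net.URI;
import java.net.http.HttpRequest;
import java.util.StringJoiner;

// Sketch: builds the REST predict request from the TensorFlow Serving docker
// example. "instances" here is a key of the JSON body, not a gRPC input alias.
final class RestPredictSketch {

    // Assumption: REST API on port 8501, the docker example's setup.
    static URI predictUri(String host, int port, String modelName) {
        return URI.create("http://" + host + ":" + port + "/v1/models/" + modelName + ":predict");
    }

    // Produces {"instances": [1.0, 2.0, 5.0]} for the example inputs.
    static String instancesJson(float[] values) {
        StringJoiner joiner = new StringJoiner(", ", "[", "]");
        for (float v : values) {
            joiner.add(Float.toString(v));
        }
        return "{\"instances\": " + joiner + "}";
    }

    static HttpRequest buildRequest(String host, int port, String modelName, float[] values) {
        return HttpRequest.newBuilder(predictUri(host, port, modelName))
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(instancesJson(values)))
                .build();
    }

    public static void main(String[] args) {
        HttpRequest request = buildRequest("localhost", 8501, "half_plus_two",
                new float[] {1.0f, 2.0f, 5.0f});
        // Prints the method and target URI; nothing is sent over the network.
        System.out.println(request.method() + " " + request.uri());
        // To send it: HttpClient.newHttpClient().send(request, HttpResponse.BodyHandlers.ofString())
    }
}
```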

Edit

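I work with Python and can't spot your error, but this is how we send prediction requests (using the PredictRequest proto). Maybe you could try the Predict proto, or maybe I'm missing something and you can spot the difference yourself.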

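import grpc
from tensorflow.core.framework import types_pb2
from tensorflow_serving.apis import predict_pb2, prediction_service_pb2_grpc

request = predict_pb2.PredictRequest()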
request.model_spec.name = model_name
request.model_spec.signature_name = signature_name
request.inputs['x'].dtype = types_pb2.DT_FLOAT
request.inputs['x'].float_val.append(2.0)

channel = grpc.insecure_channel(model_server_address)
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
result = stub.Predict(request, RPC_TIMEOUT)

I figured it out.

The answer was staring me in the face the whole time.

The exception says the input alias must be 'x':

Exception in thread "main" io.grpc.StatusRuntimeException: INVALID_ARGUMENT: input tensor alias not found in signature: inputs. Inputs expected to be in the set {x}.

and the CLI output also lists 'x' as the input name:

The given SavedModel SignatureDef contains the following input(s):
  inputs['x'] tensor_info:

So I changed the line

requestBuilder.putInputs("inputs", proto);

to

requestBuilder.putInputs("x", proto);
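A mismatched alias only surfaces once the RPC fails. A minimal sketch of a client-side pre-check, assuming the expected aliases are copied by hand from the saved_model_cli output (here just x); InputAliasCheck and its methods are hypothetical, and the error text is this sketch's own, not the server's message.

```java
import java.util.Set;
import java.util.TreeSet;

// Sketch: verify that every key passed to putInputs(...) is an input alias
// declared by the target signature, before sending the request.
final class InputAliasCheck {

    // Aliases in `provided` that the signature does not declare (sorted).
    static Set<String> unknownAliases(Set<String> provided, Set<String> expected) {
        Set<String> unknown = new TreeSet<>(provided);
        unknown.removeAll(expected);
        return unknown;
    }

    // Throws if any provided alias is not part of the signature.
    static void requireKnownAliases(Set<String> provided, Set<String> expected) {
        Set<String> unknown = unknownAliases(provided, expected);
        if (!unknown.isEmpty()) {
            throw new IllegalArgumentException(
                    "input alias(es) " + unknown + " not in signature; expected one of " + expected);
        }
    }

    public static void main(String[] args) {
        // Assumption: aliases taken from saved_model_cli, inputs['x'].
        Set<String> expected = Set.of("x");
        requireKnownAliases(Set.of("x"), expected); // passes silently
        System.out.println(unknownAliases(Set.of("inputs"), expected)); // [inputs]
    }
}
```

In the Java client, the provided set would be request.getInputsMap().keySet(), the accessor protobuf generates for the inputs map field.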

Full working code:

import com.google.protobuf.Int64Value;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import org.tensorflow.framework.DataType;
import tensorflow.serving.Predict;

public class GrpcClient {
    public static void main(String[] args) {
        String host = "localhost";
        int port = 8500;
        // the model's name.
        String modelName = "half_plus_two";
        // model's version
        long modelVersion = 123;

        // create a channel
        ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build();
        tensorflow.serving.PredictionServiceGrpc.PredictionServiceBlockingStub stub = tensorflow.serving.PredictionServiceGrpc.newBlockingStub(channel);

        // create PredictRequest
        Predict.PredictRequest.Builder requestBuilder = Predict.PredictRequest.newBuilder();

        // create ModelSpec
        tensorflow.serving.Model.ModelSpec.Builder modelSpecBuilder = tensorflow.serving.Model.ModelSpec.newBuilder();
        modelSpecBuilder.setName(modelName);
        modelSpecBuilder.setVersion(Int64Value.of(modelVersion));
        modelSpecBuilder.setSignatureName("serving_default");

        // set model for request
        requestBuilder.setModelSpec(modelSpecBuilder);

        // create TensorProto with 3 floats
        org.tensorflow.framework.TensorProto.Builder tensorProtoBuilder = org.tensorflow.framework.TensorProto.newBuilder();
        tensorProtoBuilder.setDtype(DataType.DT_FLOAT);
        tensorProtoBuilder.addFloatVal(1.0f);
        tensorProtoBuilder.addFloatVal(2.0f);
        tensorProtoBuilder.addFloatVal(5.0f);

        // create TensorShapeProto
        org.tensorflow.framework.TensorShapeProto.Builder tensorShapeBuilder = org.tensorflow.framework.TensorShapeProto.newBuilder();
        tensorShapeBuilder.addDim(org.tensorflow.framework.TensorShapeProto.Dim.newBuilder().setSize(3));

        // set shape for proto
        tensorProtoBuilder.setTensorShape(tensorShapeBuilder.build());

        // build proto
        org.tensorflow.framework.TensorProto proto = tensorProtoBuilder.build();

        // set proto for request
        requestBuilder.putInputs("x", proto);

        // build request
        Predict.PredictRequest request = requestBuilder.build();
        System.out.println("Printing request \n" + request.toString());

        // run predict
        Predict.PredictResponse response = stub.predict(request);
        System.out.println(response.toString());

        // release the channel's resources before exiting
        channel.shutdown();
    }
}
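To confirm the response is right: half_plus_two is documented to compute 0.5 * x + 2 for each input, so the docker example's [1.0, 2.0, 5.0] should come back as [2.5, 3.0, 4.5]. A small sketch of that check in plain Java; HalfPlusTwoCheck and its methods are hypothetical names, and the tolerance is an arbitrary choice.

```java
import java.util.Arrays;

// Sketch: compute the values half_plus_two should return (y = 0.5 * x + 2)
// and compare them element-wise with what the server sent back.
final class HalfPlusTwoCheck {

    static float[] expected(float[] inputs) {
        float[] out = new float[inputs.length];
        for (int i = 0; i < inputs.length; i++) {
            out[i] = 0.5f * inputs[i] + 2.0f;
        }
        return out;
    }

    // Element-wise comparison with a tolerance for float rounding.
    static boolean matches(float[] actual, float[] expected, float tolerance) {
        if (actual.length != expected.length) {
            return false;
        }
        for (int i = 0; i < actual.length; i++) {
            if (Math.abs(actual[i] - expected[i]) > tolerance) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        float[] inputs = {1.0f, 2.0f, 5.0f};
        System.out.println(Arrays.toString(expected(inputs))); // [2.5, 3.0, 4.5]
    }
}
```

With the client above, the actual values would be the float_val list of the output alias y (from the saved_model_cli output), e.g. response.getOutputsMap().get("y").getFloatValList(), copied into a float[] before calling matches.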