Java gRPC client predict call to half_plus_two example model
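I am trying to call TensorFlow Serving from a Java client. The running model is the half_plus_two example model. REST calls succeed, but I cannot make the equivalent gRPC call.
I have tried passing a string as the model input, and also a float array, to the TensorProto builder. When I print it out, the TensorProto appears to contain the correct data: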
[1.0, 2.0, 5.0]
String host = "localhost";
int port = 8500;
// the model's name.
String modelName = "half_plus_two";
// model's version
long modelVersion = 123;
// assume this model takes input of free text, and make some sentiment prediction.
// String modelInput = "some text input to make prediction with";
String modelInput = "{\"instances\": [1.0, 2.0, 5.0]";
// create a channel
ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build();
tensorflow.serving.PredictionServiceGrpc.PredictionServiceBlockingStub stub = tensorflow.serving.PredictionServiceGrpc.newBlockingStub(channel);
// create a modelspec
tensorflow.serving.Model.ModelSpec.Builder modelSpecBuilder = tensorflow.serving.Model.ModelSpec.newBuilder();
modelSpecBuilder.setName(modelName);
modelSpecBuilder.setVersion(Int64Value.of(modelVersion));
modelSpecBuilder.setSignatureName("serving_default");
Predict.PredictRequest.Builder builder = Predict.PredictRequest.newBuilder();
builder.setModelSpec(modelSpecBuilder);
// create the TensorProto and request
float[] floatData = new float[3];
floatData[0] = 1.0f;
floatData[1] = 2.0f;
floatData[2] = 5.0f;
org.tensorflow.framework.TensorProto.Builder tensorProtoBuilder = org.tensorflow.framework.TensorProto.newBuilder();
tensorProtoBuilder.setDtype(DataType.DT_FLOAT);
org.tensorflow.framework.TensorShapeProto.Builder tensorShapeBuilder = org.tensorflow.framework.TensorShapeProto.newBuilder();
tensorShapeBuilder.addDim(org.tensorflow.framework.TensorShapeProto.Dim.newBuilder().setSize(3));
tensorProtoBuilder.setTensorShape(tensorShapeBuilder.build());
// Set the float_val field.
for (int i = 0; i < floatData.length; i++) {
    tensorProtoBuilder.addFloatVal(floatData[i]);
}
org.tensorflow.framework.TensorProto tp = tensorProtoBuilder.build();
System.out.println(tp.getFloatValList());
builder.putInputs("inputs", tp);
Predict.PredictRequest request = builder.build();
Predict.PredictResponse response = stub.predict(request);
When I print the request, the shape is:
model_spec {
name: "half_plus_two"
version {
value: 123
}
signature_name: "serving_default"
}
inputs {
key: "inputs"
value {
dtype: DT_FLOAT
tensor_shape {
dim {
size: -1
}
dim {
size: 1
}
}
float_val: 1.0
float_val: 2.0
float_val: 5.0
}
}
I get this exception:
Exception in thread "main" io.grpc.StatusRuntimeException: INVALID_ARGUMENT: input tensor alias not found in signature: inputs. Inputs expected to be in the set {x}.
at io.grpc.stub.ClientCalls.toStatusRuntimeException(ClientCalls.java:233)
at io.grpc.stub.ClientCalls.getUnchecked(ClientCalls.java:214)
at io.grpc.stub.ClientCalls.blockingUnaryCall(ClientCalls.java:139)
at tensorflow.serving.PredictionServiceGrpc$PredictionServiceBlockingStub.predict(PredictionServiceGrpc.java:446)
at com.avaya.ccml.grpc.GrpcClient.main(GrpcClient.java:72)
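Edit:
Still working on this.
It looks like the TensorProto I am supplying is not correct.
I inspected the model with saved_model_cli, which shows the shape it expects: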
The given SavedModel SignatureDef contains the following input(s):
inputs['x'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 1)
name: x:0
The given SavedModel SignatureDef contains the following output(s):
outputs['y'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 1)
name: y:0
Method name is: tensorflow/serving/predict
So the next step is to figure out how to create a TensorProto with this structure.
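One way to read the `(-1, 1)` shape reported above: the first dimension is a variable batch size and the second is fixed at 1, so three values would form a 3x1 tensor. The sketch below uses only the Java standard library; `ShapeSketch` and `resolveShape` are hypothetical names for illustration, not part of the TensorFlow Serving API. Each resolved size would then go into its own `addDim(...)` call on the TensorShapeProto builder, in order.

```java
import java.util.Arrays;

// Hypothetical helper (standard library only): resolves a signature shape
// such as (-1, 1) against the number of values that will be sent.
class ShapeSketch {

    // Replaces the single -1 (unknown) dimension so that the product of all
    // dimensions equals valueCount; throws if the values cannot fit.
    static long[] resolveShape(long[] signatureShape, int valueCount) {
        long known = 1;
        int unknownIndex = -1;
        for (int i = 0; i < signatureShape.length; i++) {
            if (signatureShape[i] == -1) {
                if (unknownIndex != -1) {
                    throw new IllegalArgumentException("more than one unknown dimension");
                }
                unknownIndex = i;
            } else {
                known *= signatureShape[i];
            }
        }
        long[] resolved = signatureShape.clone();
        if (unknownIndex == -1) {
            // Fully specified shape: the element count must match exactly.
            if (known != valueCount) {
                throw new IllegalArgumentException("values do not fit the signature shape");
            }
            return resolved;
        }
        if (known == 0 || valueCount % known != 0) {
            throw new IllegalArgumentException("values do not fit the signature shape");
        }
        resolved[unknownIndex] = valueCount / known;
        return resolved;
    }

    public static void main(String[] args) {
        float[] values = {1.0f, 2.0f, 5.0f};
        long[] shape = resolveShape(new long[] {-1, 1}, values.length);
        System.out.println(Arrays.toString(shape)); // prints [3, 1]
    }
}
```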
In the half_plus_two example, they use the instances label for the input values; https://www.tensorflow.org/tfx/serving/docker#serving_example
Can you try setting the input key to instances, like this?
builder.putInputs("instances", tp);
I also think there may be a problem with the DType. I think you should use DT_FLOAT rather than DT_STRING, since that is what your inspection output shows:
tensorProtoBuilder.setDtype(DataType.DT_FLOAT);
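Edit
I work with Python and could not spot your mistake, but this is how we send a predict request (using the PredictRequest proto). Maybe you can try the Predict proto, or maybe I am missing something and you will spot the difference yourself: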
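import grpc
from tensorflow.core.framework import types_pb2
from tensorflow_serving.apis import predict_pb2, prediction_service_pb2_grpc

# model_name, signature_name, model_server_address and RPC_TIMEOUT are defined elsewhere
request = predict_pb2.PredictRequest()
request.model_spec.name = model_name
request.model_spec.signature_name = signature_name
request.inputs['x'].dtype = types_pb2.DT_FLOAT
request.inputs['x'].float_val.append(2.0)
channel = grpc.insecure_channel(model_server_address)
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
result = stub.Predict(request, RPC_TIMEOUT)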
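I figured it out.
The answer was staring me in the face the whole time.
The exception says that the input name in the signature must be 'x':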
Exception in thread "main" io.grpc.StatusRuntimeException: INVALID_ARGUMENT: input tensor alias not found in signature: inputs. Inputs expected to be in the set {x}.
And the CLI output also lists 'x' as the input name:
The given SavedModel SignatureDef contains the following input(s):
inputs['x'] tensor_info:
So I changed the line
requestBuilder.putInputs("inputs", proto);
to
requestBuilder.putInputs("x", proto);
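Since the server's error message already names the accepted keys, a client could pull them out of the status text while debugging. This is only a sketch under stated assumptions: it relies on the exact message wording shown in the exception above, assumes multiple aliases would be comma-separated inside the braces, and `ExpectedInputs` is a hypothetical helper, not part of gRPC or TensorFlow Serving.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical debugging helper: extracts the input aliases listed in the
// "expected to be in the set {...}" part of the server's error message.
class ExpectedInputs {

    private static final Pattern EXPECTED_SET =
            Pattern.compile("expected to be in the set \\{([^}]*)\\}");

    // Returns the aliases found in the message, or an empty list if the
    // message does not contain the expected phrase.
    static List<String> parse(String message) {
        List<String> names = new ArrayList<>();
        Matcher matcher = EXPECTED_SET.matcher(message);
        if (matcher.find()) {
            // Assumption: several aliases would be separated by commas.
            for (String name : matcher.group(1).split(",")) {
                String trimmed = name.trim();
                if (!trimmed.isEmpty()) {
                    names.add(trimmed);
                }
            }
        }
        return names;
    }

    public static void main(String[] args) {
        String message = "INVALID_ARGUMENT: input tensor alias not found in signature: inputs. "
                + "Inputs expected to be in the set {x}.";
        System.out.println(parse(message)); // prints [x]
    }
}
```

In a real client the message would come from the caught `StatusRuntimeException`; matching on message text is fragile, so this is better suited to troubleshooting than to production logic.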
Complete working code:
import com.google.protobuf.Int64Value;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import org.tensorflow.framework.DataType;
import tensorflow.serving.Predict;

public class GrpcClient {
    public static void main(String[] args) {
        String host = "localhost";
        int port = 8500;
        // the model's name.
        String modelName = "half_plus_two";
        // model's version
        long modelVersion = 123;

        // create a channel
        ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build();
        tensorflow.serving.PredictionServiceGrpc.PredictionServiceBlockingStub stub = tensorflow.serving.PredictionServiceGrpc.newBlockingStub(channel);

        // create PredictRequest
        Predict.PredictRequest.Builder requestBuilder = Predict.PredictRequest.newBuilder();

        // create ModelSpec
        tensorflow.serving.Model.ModelSpec.Builder modelSpecBuilder = tensorflow.serving.Model.ModelSpec.newBuilder();
        modelSpecBuilder.setName(modelName);
        modelSpecBuilder.setVersion(Int64Value.of(modelVersion));
        modelSpecBuilder.setSignatureName("serving_default");

        // set model for request
        requestBuilder.setModelSpec(modelSpecBuilder);

        // create TensorProto with 3 floats
        org.tensorflow.framework.TensorProto.Builder tensorProtoBuilder = org.tensorflow.framework.TensorProto.newBuilder();
        tensorProtoBuilder.setDtype(DataType.DT_FLOAT);
        tensorProtoBuilder.addFloatVal(1.0f);
        tensorProtoBuilder.addFloatVal(2.0f);
        tensorProtoBuilder.addFloatVal(5.0f);

        // create TensorShapeProto
        org.tensorflow.framework.TensorShapeProto.Builder tensorShapeBuilder = org.tensorflow.framework.TensorShapeProto.newBuilder();
        tensorShapeBuilder.addDim(org.tensorflow.framework.TensorShapeProto.Dim.newBuilder().setSize(3));

        // set shape for proto
        tensorProtoBuilder.setTensorShape(tensorShapeBuilder.build());

        // build proto
        org.tensorflow.framework.TensorProto proto = tensorProtoBuilder.build();

        // set proto for request
        requestBuilder.putInputs("x", proto);

        // build request
        Predict.PredictRequest request = requestBuilder.build();
        System.out.println("Printing request \n" + request.toString());

        // run predict
        Predict.PredictResponse response = stub.predict(request);
        System.out.println(response.toString());
    }
}
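To sanity-check the printed response, it helps to know what to expect: as the model name suggests, half_plus_two computes y = 0.5 * x + 2, so the inputs 1.0, 2.0, 5.0 should come back as 2.5, 3.0, 4.5 under the output key 'y'. The sketch below computes those reference values locally with the standard library only; `HalfPlusTwoCheck` is a hypothetical name and the class makes no gRPC call itself.

```java
import java.util.Arrays;

// Hypothetical local reference for the half_plus_two model: y = 0.5 * x + 2.
class HalfPlusTwoCheck {

    // Computes the values the served model is expected to return for the inputs.
    static float[] expected(float[] inputs) {
        float[] outputs = new float[inputs.length];
        for (int i = 0; i < inputs.length; i++) {
            outputs[i] = 0.5f * inputs[i] + 2.0f;
        }
        return outputs;
    }

    public static void main(String[] args) {
        float[] inputs = {1.0f, 2.0f, 5.0f};
        System.out.println(Arrays.toString(expected(inputs))); // prints [2.5, 3.0, 4.5]
    }
}
```

The float values in the response's 'y' tensor can then be compared against this array, ideally with a small tolerance rather than exact equality.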