Inferencing with Tensorflow Serving using Java
We are converting existing Java production code to use Tensorflow Serving (TFS) for inferencing. We have retrained our models and saved them using the new SavedModel format (no more frozen graphs!!).
From the documentation I have read, TFS does not directly support Java. However it does provide a gRPC interface, and that does provide a Java interface.
My question is: what are the steps involved in getting a Java application up and running against TFS?
[Edit: moved the steps to a solution]
It took four days to piece this together, since documentation and examples are still limited.
I'm sure there are better ways to do this, but this is what I have found so far:
- I cloned the tensorflow/tensorflow, tensorflow/serving, and google/protobuf repositories on GitHub.
- I compiled the following protobuf files using the protoc protobuf compiler with the grpc-java plugin. I dislike the fact that there are so many scattered .proto files to compile, but I wanted to include only the minimal set rather than pull in the many unneeded .proto files spread across the various directories. This is the minimal set I needed in order to compile my Java application:
serving_repo/tensorflow_serving/apis/*.proto
serving_repo/tensorflow_serving/config/model_server_config.proto
serving_repo/tensorflow_serving/core/logging.proto
serving_repo/tensorflow_serving/core/logging_config.proto
serving_repo/tensorflow_serving/util/status.proto
serving_repo/tensorflow_serving/sources/storage_path/file_system_storage_path_source.proto
serving_repo/tensorflow_serving/config/log_collector_config.proto
tensorflow_repo/tensorflow/core/framework/tensor.proto
tensorflow_repo/tensorflow/core/framework/tensor_shape.proto
tensorflow_repo/tensorflow/core/framework/types.proto
tensorflow_repo/tensorflow/core/framework/resource_handle.proto
tensorflow_repo/tensorflow/core/example/example.proto
tensorflow_repo/tensorflow/core/protobuf/tensorflow_server.proto
tensorflow_repo/tensorflow/core/example/feature.proto
tensorflow_repo/tensorflow/core/protobuf/named_tensor.proto
tensorflow_repo/tensorflow/core/protobuf/config.proto
- Note that protoc will compile even withOUT grpc-java, but most of the critical entry points will mysteriously be missing. If PredictionServiceGrpc.java is missing, then grpc-java did not run.
- Command-line example (line breaks inserted for readability):
$ ./protoc -I=/Users/foobar/protobuf_repo/src \
-I=/Users/foobar/tensorflow_repo \
-I=/Users/foobar/tfserving_repo \
--plugin=protoc-gen-grpc-java=/Users/foobar/protoc-gen-grpc-java-1.20.0-osx-x86_64.exe \
--java_out=src \
--grpc-java_out=src \
/Users/foobar/tfserving_repo/tensorflow_serving/apis/*.proto
- Following the gRPC documentation, I created a channel and a stub:
ManagedChannel mChannel;
PredictionServiceGrpc.PredictionServiceBlockingStub mBlockingstub;
mChannel = ManagedChannelBuilder.forAddress(host,port).usePlaintext().build();
mBlockingstub = PredictionServiceGrpc.newBlockingStub(mChannel);
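When the client is done, the channel should be shut down cleanly; a minimal sketch (the close() method and the 5-second timeout are my own choices, not part of the original code):
// Shut down the gRPC channel, giving in-flight RPCs a short grace period to finish.
public void close() throws InterruptedException {
    mChannel.shutdown().awaitTermination(5, java.util.concurrent.TimeUnit.SECONDS);
}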
- Piecing the following together from several documents:
- The Maven imports are:
io.grpc:grpc-all
org.tensorflow:libtensorflow
org.tensorflow:proto
com.google.protobuf:protobuf-java
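With those artifacts on the classpath, the sample code below assumes imports along these lines (the package names follow the java_package / proto package declarations in the stock .proto files; adjust them if your generated code ends up in different packages):
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.StatusRuntimeException;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.TensorProto;
import org.tensorflow.framework.TensorShapeProto;
import tensorflow.serving.Model;
import tensorflow.serving.Predict;
import tensorflow.serving.PredictionServiceGrpc;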
- Sample code is as follows:
// Generate features TensorProto
TensorProto.Builder featuresTensorBuilder = TensorProto.newBuilder();
TensorShapeProto.Dim featuresDim1 = TensorShapeProto.Dim.newBuilder().setSize(1).build();
TensorShapeProto featuresShape = TensorShapeProto.newBuilder().addDim(featuresDim1).build();
featuresTensorBuilder.setDtype(org.tensorflow.framework.DataType.DT_FLOAT).setTensorShape(featuresShape); // dtype must match the model's input signature; DT_FLOAT assumed here
TensorProto featuresTensorProto = featuresTensorBuilder.build();
// Now prepare for the inference request over gRPC to the TF Serving server
com.google.protobuf.Int64Value version = com.google.protobuf.Int64Value.newBuilder().setValue(mGraphVersion).build();
Model.ModelSpec.Builder model = Model.ModelSpec
.newBuilder()
.setName(mGraphName)
.setVersion(version); // type = Int64Value
Model.ModelSpec modelSpec = model.build();
Predict.PredictRequest request;
request = Predict.PredictRequest.newBuilder()
.setModelSpec(modelSpec)
.putInputs("image", featuresTensorProto)
.build();
Predict.PredictResponse response;
try {
response = mBlockingstub.predict(request);
// Refer to https://github.com/thammegowda/tensorflow-grpc-java/blob/master/src/main/java/edu/usc/irds/tensorflow/grpc/TensorflowObjectRecogniser.java
java.util.Map<java.lang.String, org.tensorflow.framework.TensorProto> outputs = response.getOutputsMap();
for (java.util.Map.Entry<java.lang.String, org.tensorflow.framework.TensorProto> entry : outputs.entrySet()) {
System.out.println("Response with the key: " + entry.getKey() + ", value: " + entry.getValue());
}
} catch (StatusRuntimeException e) {
logger.log(Level.WARNING, "RPC failed: {0}", e.getStatus());
success = false;
}
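Note that featuresTensorProto above only carries a dtype and a shape; the input values themselves still have to be added before build() is called, and the outputs come back as TensorProtos that must be unpacked. A minimal sketch, assuming a DT_FLOAT input and a float output tensor named "scores" (both the value and the output name are placeholders, not from my model):
// Input values for a DT_FLOAT tensor go into the repeated float_val field,
// before featuresTensorBuilder.build() is called.
featuresTensorBuilder.addFloatVal(42.0f); // illustrative value only

// After the RPC, pull a specific output tensor out of the response map and
// read its values back (again assuming a DT_FLOAT output).
TensorProto scores = response.getOutputsMap().get("scores"); // "scores" is a placeholder output name
if (scores != null) {
    java.util.List<Float> values = scores.getFloatValList();
    System.out.println("First score: " + values.get(0));
}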