在 Java 中从 Tensorflow 服务模型获取矩阵响应
Get Matrix response from Tensorflow serving model in Java
我目前正在 Python 中构建模型并从另一个 Java 客户端获取结果。
我需要知道如何从具有超过 1 个维度的 TensorProto 中获取 float[][]
或 List<List<Float>>
(类似的东西)。
在Python,做这个工作可能很容易:
from tensorflow.python.framework import tensor_util
.
.
.
print tensor_util.MakeNdarray(tensorProto)
===== 更新 =======:
如果 Java 的 tensorProto.getFloatValList()
是由 Python 的 tensor_util.make_tensor_proto(vector)
创建的,那么它也不起作用。
以上情况都可以通过@Ash的回答解决
正如 Allen 在评论中提到的,这可能是一个很好的功能请求。
但在此期间,一个解决方法是构建和 运行 一个解析编码的 protobuf 和 returns 一个 Tensor
的图。它不会特别有效,但你可以这样做:
import org.tensorflow.*;
import java.util.Arrays;
public final class ProtoToTensor {
public static Tensor<Float> tensorFromSerializedProto(byte[] serialized) {
// One may way to cache the Graph and Session as member variables to avoid paying the cost of
// graph and session construction on each call.
try (Graph g = buildGraphToParseProto();
Session sess = new Session(g);
Tensor<String> input = Tensors.create(serialized)) {
return sess.runner()
.feed("input", input)
.fetch("output")
.run()
.get(0)
.expect(Float.class);
}
}
private static Graph buildGraphToParseProto() {
Graph g = new Graph();
// The graph construction process in Java is currently (as of TensorFlow 1.4) very verbose.
// Once https://github.com/tensorflow/tensorflow/issues/7149 is resolved, this should become
// *much* more convenient and succint.
Output<String> in =
g.opBuilder("Placeholder", "input")
.setAttr("dtype", DataType.STRING)
.setAttr("shape", Shape.scalar())
.build()
.output(0);
g.opBuilder("ParseTensor", "output").setAttr("out_type", DataType.FLOAT).addInput(in).build();
return g;
}
public static void main(String[] args) {
// Let's say you got a byte[] representation of the proto somehow.
// In this case, I got it from Python from the following program
// that serializes the 1x1 matrix:
/*
import tensorflow as tf
list(bytearray(tf.make_tensor_proto([[1.]]).SerializeToString()))
*/
byte[] bytes = {8, 1, 18, 8, 18, 2, 8, 1, 18, 2, 8, 1, 42, 4, 0, 0, (byte)128, 63};
try (Tensor<Float> t = tensorFromSerializedProto(bytes)) {
// You can now get an float[][] array using t.copyTo().
// t.shape() gives shape information.
System.out.println("Tensor: " + t);
float[][] f = t.copyTo(new float[1][1]);
System.out.println("float[][]: " + Arrays.deepToString(f));
}
}
}
如您所见,这是使用一些非常低级的 API 来构建图形和会话。有一个用一行替换所有这些的功能请求是合理的:
Tensor<Float> t = Tensor.createFromProto(serialized);
我目前正在 Python 中构建模型并从另一个 Java 客户端获取结果。
我需要知道如何从具有超过 1 个维度的 TensorProto 中获取 float[][]
或 List<List<Float>>
(类似的东西)。
在Python,做这个工作可能很容易:
from tensorflow.python.framework import tensor_util
.
.
.
print tensor_util.MakeNdarray(tensorProto)
===== 更新 =======:
如果Java 的 tensorProto.getFloatValList()
是由 Python 的 tensor_util.make_tensor_proto(vector)
创建的,那么它也不起作用。
以上情况都可以通过@Ash的回答解决
正如 Allen 在评论中提到的,这可能是一个很好的功能请求。
但在此期间,一个解决方法是构建和 运行 一个解析编码的 protobuf 和 returns 一个 Tensor
的图。它不会特别有效,但你可以这样做:
import org.tensorflow.*;
import java.util.Arrays;
public final class ProtoToTensor {
public static Tensor<Float> tensorFromSerializedProto(byte[] serialized) {
// One may way to cache the Graph and Session as member variables to avoid paying the cost of
// graph and session construction on each call.
try (Graph g = buildGraphToParseProto();
Session sess = new Session(g);
Tensor<String> input = Tensors.create(serialized)) {
return sess.runner()
.feed("input", input)
.fetch("output")
.run()
.get(0)
.expect(Float.class);
}
}
private static Graph buildGraphToParseProto() {
Graph g = new Graph();
// The graph construction process in Java is currently (as of TensorFlow 1.4) very verbose.
// Once https://github.com/tensorflow/tensorflow/issues/7149 is resolved, this should become
// *much* more convenient and succint.
Output<String> in =
g.opBuilder("Placeholder", "input")
.setAttr("dtype", DataType.STRING)
.setAttr("shape", Shape.scalar())
.build()
.output(0);
g.opBuilder("ParseTensor", "output").setAttr("out_type", DataType.FLOAT).addInput(in).build();
return g;
}
public static void main(String[] args) {
// Let's say you got a byte[] representation of the proto somehow.
// In this case, I got it from Python from the following program
// that serializes the 1x1 matrix:
/*
import tensorflow as tf
list(bytearray(tf.make_tensor_proto([[1.]]).SerializeToString()))
*/
byte[] bytes = {8, 1, 18, 8, 18, 2, 8, 1, 18, 2, 8, 1, 42, 4, 0, 0, (byte)128, 63};
try (Tensor<Float> t = tensorFromSerializedProto(bytes)) {
// You can now get an float[][] array using t.copyTo().
// t.shape() gives shape information.
System.out.println("Tensor: " + t);
float[][] f = t.copyTo(new float[1][1]);
System.out.println("float[][]: " + Arrays.deepToString(f));
}
}
}
如您所见,这是使用一些非常低级的 API 来构建图形和会话。有一个用一行替换所有这些的功能请求是合理的:
Tensor<Float> t = Tensor.createFromProto(serialized);