如何使用 golang 将带有 shape=[?] 的输入字符串提供给 tensorflow 模型
How to feed input string with shape=[?] to tensorflow model using golang
Python火车模型代码:
input_schema = dataset_schema.from_feature_spec({
REVIEW_COLUMN: tf.FixedLenFeature(shape=[], dtype=tf.string),
LABEL_COLUMN: tf.FixedLenFeature(shape=[], dtype=tf.int64)
})
在 python 中预测工作正常。客户示例:
loaded_model = tf.saved_model.loader.load(sess, ["serve"], '/tmp/model/export/Servo/1506084916')
input_dict, output_dict =_signature_def_to_tensors(loaded_model.signature_def['default_input_alternative:None'])
start = datetime.datetime.now()
out = sess.run(output_dict, feed_dict={input_dict["inputs"]: ("I went and saw this movie last night",)})
print(out)
print("Time all: ", datetime.datetime.now() - start)
但是golang客户端不工作:
m, err := tf.LoadSavedModel("/tmp/model/export/Servo/1506084916", []string{"serve"}, &tf.SessionOptions{})
if err != nil {
panic(fmt.Errorf("load model: %s", err))
}
data := "I went and saw this movie last night"
t, err := tf.NewTensor([]string{data})
if err != nil {
panic(fmt.Errorf("tensor err: %s", err))
}
fmt.Printf("tensor: %v", t.Shape())
output, err = m.Session.Run(
map[tf.Output]*tf.Tensor{
m.Graph.Operation("save_1/StringJoin/inputs_1").Output(0): t,
}, []tf.Output{
m.Graph.Operation("linear/binary_logistic_head/predictions/classes").Output(0),
}, nil,
)
if err != nil {
panic(fmt.Errorf("run model: %s", err))
}
我收到错误:
panic: run model: You must feed a value for placeholder tensor
'Placeholder' with dtype string and shape [?]
[[Node: Placeholder = Placeholder_output_shapes=[[?]], dtype=DT_STRING, shape=[?],
_device="/job:localhost/replica:0/task:0/cpu:0"]]
如何用 golang 呈现 shape=[?]
张量?或者我需要更改 python 训练脚本的输入格式?
更新:
这个字符串"save_1/StringJoin/inputs_1"
我在运行之后收到这个python-代码:
for n in sess.graph.as_graph_def().node:
if "inputs" in n.name:
print(n.name)
输出:
transform/transform/inputs/review/Placeholder
transform/transform/inputs/review/Identity
transform/transform/inputs/label/Placeholder
transform/transform/inputs/label/Identity
transform/transform_1/inputs/review/Placeholder
transform/transform_1/inputs/review/Identity
transform/transform_1/inputs/label/Placeholder
transform/transform_1/inputs/label/Identity
save_1/StringJoin/inputs_1
save_2/StringJoin/inputs_1
错误告诉您 You must feed a value for placeholder tensor 'Placeholder'
:这意味着在您为该占位符提供值之前无法构建图表。
在您的 python 代码中,您将其输入行:
input_dict["inputs"]: ("I went and saw this movie last night",)
实际上,input_dict["inputs"]
的计算结果为:<tf.Tensor 'Placeholder:0' shape=(?,) dtype=string>
。
相反,在您的 Go 代码中,您正在寻找一个名为 save_1/StringJoin/inputs_1
的张量,它不是占位符。
要遵循的规则是:在 Python 和 Go 中使用相同的输入。
要解决这个问题,因此,你只需要从图中提取名为Placeholder
的占位符(就像在python中一样),然后使用它。
m.Graph.Operation("Placeholder").Output(0): t,
此外,我建议您使用更完整且易于使用的 tensorflow 包装器 API:tfgo
还有一件事。我阅读了 TF 文档并发现了这个 topic
它有助于找到正确的 input/output 键,响应示例:
The given SavedModel SignatureDef contains the following input(s): inputs['inputs'] tensor_info:
dtype: DT_STRING
shape: (-1)
name: Placeholder:0
PS/ 发布为关注的答案
Python火车模型代码:
input_schema = dataset_schema.from_feature_spec({
REVIEW_COLUMN: tf.FixedLenFeature(shape=[], dtype=tf.string),
LABEL_COLUMN: tf.FixedLenFeature(shape=[], dtype=tf.int64)
})
在 python 中预测工作正常。客户示例:
loaded_model = tf.saved_model.loader.load(sess, ["serve"], '/tmp/model/export/Servo/1506084916')
input_dict, output_dict =_signature_def_to_tensors(loaded_model.signature_def['default_input_alternative:None'])
start = datetime.datetime.now()
out = sess.run(output_dict, feed_dict={input_dict["inputs"]: ("I went and saw this movie last night",)})
print(out)
print("Time all: ", datetime.datetime.now() - start)
但是golang客户端不工作:
m, err := tf.LoadSavedModel("/tmp/model/export/Servo/1506084916", []string{"serve"}, &tf.SessionOptions{})
if err != nil {
panic(fmt.Errorf("load model: %s", err))
}
data := "I went and saw this movie last night"
t, err := tf.NewTensor([]string{data})
if err != nil {
panic(fmt.Errorf("tensor err: %s", err))
}
fmt.Printf("tensor: %v", t.Shape())
output, err = m.Session.Run(
map[tf.Output]*tf.Tensor{
m.Graph.Operation("save_1/StringJoin/inputs_1").Output(0): t,
}, []tf.Output{
m.Graph.Operation("linear/binary_logistic_head/predictions/classes").Output(0),
}, nil,
)
if err != nil {
panic(fmt.Errorf("run model: %s", err))
}
我收到错误:
panic: run model: You must feed a value for placeholder tensor 'Placeholder' with dtype string and shape [?] [[Node: Placeholder = Placeholder_output_shapes=[[?]], dtype=DT_STRING, shape=[?], _device="/job:localhost/replica:0/task:0/cpu:0"]]
如何用 golang 呈现 shape=[?]
张量?或者我需要更改 python 训练脚本的输入格式?
更新:
这个字符串"save_1/StringJoin/inputs_1"
我在运行之后收到这个python-代码:
for n in sess.graph.as_graph_def().node:
if "inputs" in n.name:
print(n.name)
输出:
transform/transform/inputs/review/Placeholder
transform/transform/inputs/review/Identity
transform/transform/inputs/label/Placeholder
transform/transform/inputs/label/Identity
transform/transform_1/inputs/review/Placeholder
transform/transform_1/inputs/review/Identity
transform/transform_1/inputs/label/Placeholder
transform/transform_1/inputs/label/Identity
save_1/StringJoin/inputs_1
save_2/StringJoin/inputs_1
错误告诉您 You must feed a value for placeholder tensor 'Placeholder'
:这意味着在您为该占位符提供值之前无法构建图表。
在您的 python 代码中,您将其输入行:
input_dict["inputs"]: ("I went and saw this movie last night",)
实际上,input_dict["inputs"]
的计算结果为:<tf.Tensor 'Placeholder:0' shape=(?,) dtype=string>
。
相反,在您的 Go 代码中,您正在寻找一个名为 save_1/StringJoin/inputs_1
的张量,它不是占位符。
要遵循的规则是:在 Python 和 Go 中使用相同的输入。
要解决这个问题,因此,你只需要从图中提取名为Placeholder
的占位符(就像在python中一样),然后使用它。
m.Graph.Operation("Placeholder").Output(0): t,
此外,我建议您使用更完整且易于使用的 tensorflow 包装器 API:tfgo
还有一件事。我阅读了 TF 文档并发现了这个 topic
它有助于找到正确的 input/output 键,响应示例:
The given SavedModel SignatureDef contains the following input(s): inputs['inputs'] tensor_info:
dtype: DT_STRING
shape: (-1)
name: Placeholder:0
PS/ 发布为关注的答案