如何检索导出函数的输出?

How to retrieve output of exported function?

从导出的 TensorFlow 模型调用函数时,我收到两个字符串(“output_0”、“output_1”)而不是实际模型。我如何才能获取与此字符串关联的张量以访问输出?

导出模型:

class OneStep(tf.keras.Model):
  ...
  @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string), tf.TensorSpec(shape=[None, None], dtype=tf.float32)])
  def generate_one_step(self, inputs, states):

  ...
  @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
  def generate_one_step_none(self, inputs):

...
tf.saved_model.save(one_step_model, 'one_Step', signatures={ "generate_one_step": one_step_model.generate_one_step, "generate_one_step_none": one_step_model.generate_one_step_none})

要导入的代码:

one_step = tf.saved_model.load('one_step')
step_gen = one_step.signatures["generate_one_step"]
step_gen_none = one_step.signatures["generate_one_step_none"]

next_char = tf.constant(['Test'], tf.string)

a, b = step_gen_none(inputs=next_char)
print(a,b) # returns "input_0", "input_1"

事实证明,您必须存储函数调用的结果,然后将其作为数组访问。正确答案是

res = step_gen_none(inputs=next_char)
a = res["output_0"]
b = res["output_1"]

名称也可以更改为描述的不太通用的名称here. I found the answer in a migration guide.虽然与答案无关,但值得指出的是 Tensorflow 2 没有很好的文档记录,您通常不应信任任何 Tensorflow 来源在网络上,除非他们明确提到 v2.