Inference on pre-trained ONNX model from Unity ml-agents in Tensorflow

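I have a pre-trained model from Unity's ml-agents, and I'm trying to run inference on it in Python with TensorFlow. To do that, I used the TensorFlow Backend for ONNX (onnx-tf) to export the ONNX model as a SavedModel so I can load it later. The code I use to save the model is: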

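import onnx
from onnx_tf.backend import prepare

model_path = "model.onnx"        # placeholder: path to the ml-agents .onnx file
output_path = "saved_model_dir"  # placeholder: directory for the exported SavedModel

onnx_model = onnx.load(model_path)                   # load the ONNX model
tf_rep = prepare(onnx_model, logging_level='DEBUG')  # convert to a TensorFlow representation
tf_rep.export_graph(output_path)                     # write it out as a SavedModel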

The code to load the model and run a test example:

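import tensorflow as tf

# model_dir: the SavedModel directory written by tf_rep.export_graph(output_path)
# forward, body: the two visual observations (image arrays) fed to the agent
imported = tf.saved_model.load(model_dir, tags=None, options=None)
f = imported.signatures["serving_default"]  # the exported inference function
print(f(visual_observation_0=tf.cast(forward, tf.float32),
        visual_observation_1=tf.cast(body, tf.float32)))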
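Because the signature returns every output of the converted graph, it can help to list the names it exposes before calling it. A minimal sketch that only reads the `structured_input_signature` and `structured_outputs` attributes of the loaded signature function `f`; `describe_signature` is a hypothetical helper, and the exact output names may depend on the onnx-tf version used for the export:

```python
def describe_signature(fn) -> dict:
    """Summarize the inputs and outputs of a loaded SavedModel signature.

    Assumes `fn` is a signature function such as imported.signatures["serving_default"],
    whose structured_input_signature is (args, kwargs) with TensorSpecs in kwargs and
    whose structured_outputs is a dict keyed by output name.
    """
    _, input_specs = fn.structured_input_signature
    inputs = {name: (str(spec.shape), spec.dtype.name) for name, spec in input_specs.items()}
    outputs = sorted(fn.structured_outputs)  # output names, sorted alphabetically
    return {"inputs": inputs, "outputs": outputs}
```

`print(describe_signature(f))` then shows which keys to index into the dict returned by `f(...)`.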

There are a couple of problems:

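  1. The test produces 6 output values (see the image below for a visualization of the ONNX graph).
  2. I get the following messages when saving the model (see the debug output below).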

I'm not sure what is happening here. Any help is much appreciated.

2021-03-24 17:52:03,267 - onnx-tf - DEBUG - Unknown op Celu in domain 'ai.onnx'.
2021-03-24 17:52:03,267 - onnx-tf - DEBUG - Fail to get since_version of ConcatFromSequence in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,267 - onnx-tf - DEBUG - Unknown op ConstantFill in domain 'ai.onnx'.
2021-03-24 17:52:03,267 - onnx-tf - DEBUG - Fail to get since_version of ConvInteger in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,267 - onnx-tf - DEBUG - Fail to get since_version of CumSum in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,267 - onnx-tf - DEBUG - Fail to get since_version of DequantizeLinear in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,267 - onnx-tf - DEBUG - Fail to get since_version of Det in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,268 - onnx-tf - DEBUG - Fail to get since_version of DynamicQuantizeLinear in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,268 - onnx-tf - DEBUG - Unknown op Einsum in domain 'ai.onnx'.
2021-03-24 17:52:03,268 - onnx-tf - DEBUG - Fail to get since_version of GatherElements in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,268 - onnx-tf - DEBUG - Fail to get since_version of GatherND in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,268 - onnx-tf - DEBUG - Unknown op GreaterOrEqual in domain 'ai.onnx'.
2021-03-24 17:52:03,268 - onnx-tf - DEBUG - Unknown op ImageScaler in domain 'ai.onnx'.
2021-03-24 17:52:03,268 - onnx-tf - DEBUG - Fail to get since_version of IsInf in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,268 - onnx-tf - DEBUG - Unknown op LessOrEqual in domain 'ai.onnx'.
2021-03-24 17:52:03,269 - onnx-tf - DEBUG - Fail to get since_version of MatMulInteger in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,269 - onnx-tf - DEBUG - Fail to get since_version of Mod in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,269 - onnx-tf - DEBUG - Fail to get since_version of NonMaxSuppression in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,269 - onnx-tf - DEBUG - Fail to get since_version of QLinearConv in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,269 - onnx-tf - DEBUG - Fail to get since_version of QLinearMatMul in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,269 - onnx-tf - DEBUG - Fail to get since_version of QuantizeLinear in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,269 - onnx-tf - DEBUG - Fail to get since_version of Range in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,269 - onnx-tf - DEBUG - Fail to get since_version of Resize in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,270 - onnx-tf - DEBUG - Fail to get since_version of ReverseSequence in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,270 - onnx-tf - DEBUG - Fail to get since_version of RoiAlign in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,270 - onnx-tf - DEBUG - Fail to get since_version of Round in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,270 - onnx-tf - DEBUG - Fail to get since_version of ScatterElements in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,270 - onnx-tf - DEBUG - Fail to get since_version of ScatterND in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,270 - onnx-tf - DEBUG - Fail to get since_version of SequenceAt in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,270 - onnx-tf - DEBUG - Fail to get since_version of SequenceConstruct in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,270 - onnx-tf - DEBUG - Fail to get since_version of SequenceEmpty in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,270 - onnx-tf - DEBUG - Fail to get since_version of SequenceErase in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,270 - onnx-tf - DEBUG - Fail to get since_version of SequenceInsert in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,270 - onnx-tf - DEBUG - Fail to get since_version of SequenceLength in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,270 - onnx-tf - DEBUG - Fail to get since_version of SplitToSequence in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03,271 - onnx-tf - DEBUG - Fail to get since_version of ThresholdedRelu in domain '' with max_inclusive_version=9. Set to 1.
2021-03-24 17:52:03.273323: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN)to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2021-03-24 17:52:03.286901: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7f912d05cf60 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2021-03-24 17:52:03.286913: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
2021-03-24 17:52:07.450878: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
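Every onnx-tf DEBUG line above reports `max_inclusive_version=9`, which suggests onnx-tf is matching its operator handlers against opset 9 (presumably the opset declared by the ml-agents model). As far as I can tell, these lines list operators that onnx-tf could not look up for that opset, so they should only matter if the model actually uses one of them. A minimal sketch to check that, assuming the log has been captured as a string; both helper names are hypothetical:

```python
import re

# Patterns for the two kinds of onnx-tf DEBUG messages in the log above.
UNKNOWN_OP = re.compile(r"Unknown op (\w+) in domain")
NO_SINCE_VERSION = re.compile(r"Fail to get since_version of (\w+) in domain")


def ops_mentioned_in_log(log_text: str) -> set:
    """Collect operator names that onnx-tf flagged in its DEBUG output."""
    ops = set()
    for line in log_text.splitlines():
        match = UNKNOWN_OP.search(line) or NO_SINCE_VERSION.search(line)
        if match:
            ops.add(match.group(1))
    return ops


def flagged_ops_used_by_model(log_text: str, model_op_types) -> set:
    """Return the flagged operators that the model's graph actually contains."""
    return ops_mentioned_in_log(log_text) & set(model_op_types)
```

With `model_op_types = {n.op_type for n in onnx_model.graph.node}`, an empty result means none of the flagged operators appear in the top-level graph.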

OK, it turns out the network's outputs also include parameters such as the version number:

import onnx
from onnx_tf.backend import prepare

onnx_model = onnx.load(model_path)  # load the ONNX model
tf_rep = prepare(onnx_model)        # convert it to a TensorFlow representation

print(tf_rep.inputs)   # input nodes of the model
# ['visual_observation_0', 'visual_observation_1']
print(tf_rep.outputs)  # output nodes of the model
# ['version_number', 'memory_size', 'continuous_actions', 'continuous_action_output_shape', 'action', 'is_continuous_control', 'action_output_shape']

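The inputs match what I expected. However, the outputs also include the version number, memory size, and so on, while I'm only interested in continuous_actions. I also have to scale the images to [0, 1].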
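A minimal sketch of both steps, assuming the observations are 8-bit images (pixel values 0 to 255) and that the signature's output dict is keyed by the ONNX output names (print `f.structured_outputs` to confirm); `to_unit_range` and `pick_continuous_actions` are hypothetical helpers:

```python
import numpy as np


def to_unit_range(image) -> np.ndarray:
    """Scale 8-bit pixel values (0..255) to float32 in [0, 1].

    Assumes each observation is either one image (3-D), which gets a leading
    batch dimension, or an already batched array (4-D), which is left as-is.
    """
    arr = np.asarray(image, dtype=np.float32) / 255.0
    if arr.ndim == 3:
        arr = arr[np.newaxis, ...]
    return arr


def pick_continuous_actions(outputs, key="continuous_actions"):
    """Keep only the action output and drop version_number, memory_size, etc."""
    return outputs[key]


# Usage with the signature `f` from the question (requires TensorFlow):
#   out = f(visual_observation_0=tf.constant(to_unit_range(forward)),
#           visual_observation_1=tf.constant(to_unit_range(body)))
#   actions = pick_continuous_actions(out)
```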