Accessing Weights and Layers in TensorFlow Hub

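When I load a model from the TensorFlow Hub repository, I can see that it is in SavedModel format, but I cannot access the model architecture or the weights stored in each layer.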

import tensorflow_hub as hub

# Load the CenterNet Hourglass detector; hub.load returns a restored SavedModel object
model = hub.load("https://tfhub.dev/tensorflow/centernet/hourglass_512x512/1")

Is there an official way to access them?

None of the attributes I get from model.__dict__ clearly correspond to specific layers of the original model:

{'_self_setattr_tracking': True,
 '_self_unconditional_checkpoint_dependencies': [TrackableReference(name='_model', ref=<tensorflow.python.saved_model.load.Loader._recreate_base_user_object.<locals>._UserObject object at 0x7fe4e4914710>),
  TrackableReference(name='signatures', ref=_SignatureMap({'serving_default': <ConcreteFunction signature_wrapper(input_tensor) at 0x7FE4E601F210>})),
  TrackableReference(name='_self_saveable_object_factories', ref=DictWrapper({}))],
 '_self_unconditional_dependency_names': {'_model': <tensorflow.python.saved_model.load.Loader._recreate_base_user_object.<locals>._UserObject at 0x7fe4e4914710>,
  'signatures': _SignatureMap({'serving_default': <ConcreteFunction signature_wrapper(input_tensor) at 0x7FE4E601F210>}),
  '_self_saveable_object_factories': {}},
 '_self_unconditional_deferred_dependencies': {},
 '_self_update_uid': 176794,
 '_self_name_based_restores': set(),
 '_self_saveable_object_factories': {},
 '_model': <tensorflow.python.saved_model.load.Loader._recreate_base_user_object.<locals>._UserObject at 0x7fe4e4914710>,
 'signatures': _SignatureMap({'serving_default': <ConcreteFunction signature_wrapper(input_tensor) at 0x7FE4E601F210>}),
 '__call__': <tensorflow.python.saved_model.function_deserialization.RestoredFunction at 0x7fe315a28950>,
 'graph_debug_info': ,
 'tensorflow_version': '2.4.0',
 'tensorflow_git_version': 'unknown'}
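Most of these entries are bookkeeping added by TensorFlow's object tracking (the _self_* keys). A minimal sketch that filters them out so that only restored children, signatures and functions remain. To stay runnable without a download, it saves and reloads a hypothetical TinyModel (my own stand-in, not part of the Hub model); on the real Hub object the same helper would presumably leave _model as the child worth exploring further.

```python
import tempfile

import tensorflow as tf


class TinyModel(tf.Module):
    """Hypothetical stand-in for the Hub model: one dense layer, no download."""

    def __init__(self):
        super().__init__()
        self.kernel = tf.Variable(tf.ones([3, 2]), name="kernel")
        self.bias = tf.Variable(tf.zeros([2]), name="bias")

    @tf.function(input_signature=[tf.TensorSpec([None, 3], tf.float32)])
    def __call__(self, input_tensor):
        return {"scores": tf.matmul(input_tensor, self.kernel) + self.bias}


def public_attributes(obj):
    """Attribute names of a restored object, minus object-tracking bookkeeping."""
    return sorted(k for k in vars(obj) if not k.startswith("_self_"))


export_dir = tempfile.mkdtemp()
model = TinyModel()
tf.saved_model.save(
    model,
    export_dir,
    signatures={"serving_default": model.__call__.get_concrete_function()},
)

# hub.load wraps tf.saved_model.load, so it returns the same kind of object
loaded = tf.saved_model.load(export_dir)
print(public_attributes(loaded))
```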

I also tried model.signatures['serving_default'].__dict__, but the tensors of the individual layers are not visible there either:

  [<tf.Tensor: shape=(), dtype=resource, numpy=<unprintable>>,
  <tf.Tensor: shape=(), dtype=resource, numpy=<unprintable>>,
  <tf.Tensor: shape=(), dtype=resource, numpy=<unprintable>>,
  <tf.Tensor: shape=(), dtype=resource, numpy=<unprintable>>,
  <tf.Tensor: shape=(), dtype=resource, numpy=<unprintable>>,
  <tf.Tensor: shape=(), dtype=resource, numpy=<unprintable>>,
  <tf.Tensor: shape=(), dtype=resource, numpy=<unprintable>>,
  <tf.Tensor: shape=(), dtype=resource, numpy=<unprintable>>,
  <tf.Tensor: shape=(), dtype=resource, numpy=<unprintable>>,
  <tf.Tensor: shape=(), dtype=resource, numpy=<unprintable>>,
  <tf.Tensor: shape=(), dtype=resource, numpy=<unprintable>>,
  <tf.Tensor: shape=(), dtype=resource, numpy=<unprintable>>,
  <tf.Tensor: shape=(), dtype=resource, numpy=<unprintable>>,
  <tf.Tensor: shape=(), dtype=resource, numpy=<unprintable>>],

The SavedModel can be inspected with saved_model_cli, a command-line tool that ships with the tensorflow pip package. As a first step, I downloaded and cached the model:

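from os import environ

# Set the cache location before tensorflow_hub resolves the handle,
# so the downloaded SavedModel lands in a known directory
environ['TFHUB_CACHE_DIR'] = '/Users/you/.cache/tfhub_modules'

import tensorflow_hub as hub

hub.load("https://tfhub.dev/tensorflow/centernet/hourglass_512x512/1")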

Then I inspected the model's signatures:

saved_model_cli show --dir /Users/you/.cache/tfhub_modules/3085eb2fbe2ad0b69801d50844c97b7a7a5ecade --all
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['__saved_model_init_op']:
  The given SavedModel SignatureDef contains the following input(s):
  The given SavedModel SignatureDef contains the following output(s):
    outputs['__saved_model_init_op'] tensor_info:
        dtype: DT_INVALID
        shape: unknown_rank
        name: NoOp
  Method name is:

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['input_tensor'] tensor_info:
        dtype: DT_UINT8
        shape: (1, -1, -1, 3)
        name: serving_default_input_tensor:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['detection_boxes'] tensor_info:
        dtype: DT_FLOAT
        shape: (1, 100, 4)
        name: StatefulPartitionedCall:0
    outputs['detection_classes'] tensor_info:
        dtype: DT_FLOAT
        shape: (1, 100)
        name: StatefulPartitionedCall:1
    outputs['detection_scores'] tensor_info:
        dtype: DT_FLOAT
        shape: (1, 100)
        name: StatefulPartitionedCall:2
    outputs['num_detections'] tensor_info:
        dtype: DT_FLOAT
        shape: (1)
        name: StatefulPartitionedCall:3
  Method name is: tensorflow/serving/predict
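Roughly the same input/output information can be read from Python through the ConcreteFunction attributes structured_input_signature and structured_outputs. A minimal sketch, assuming a hypothetical TinyModel saved locally instead of the Hub download; with the real model you would pass model.signatures['serving_default'] to the (also hypothetical) describe_signature helper:

```python
import tempfile

import tensorflow as tf


class TinyModel(tf.Module):
    """Hypothetical stand-in for the Hub model (saved locally, no download)."""

    def __init__(self):
        super().__init__()
        self.kernel = tf.Variable(tf.ones([3, 2]), name="kernel")
        self.bias = tf.Variable(tf.zeros([2]), name="bias")

    @tf.function(input_signature=[tf.TensorSpec([None, 3], tf.float32)])
    def __call__(self, input_tensor):
        return {"scores": tf.matmul(input_tensor, self.kernel) + self.bias}


def describe_signature(fn):
    """Map each signature input/output name to (dtype name, shape list)."""
    # Signature functions take keyword arguments only, so args is empty
    _, input_kwargs = fn.structured_input_signature
    return {
        "inputs": {k: (s.dtype.name, s.shape.as_list()) for k, s in input_kwargs.items()},
        "outputs": {k: (s.dtype.name, s.shape.as_list()) for k, s in fn.structured_outputs.items()},
    }


export_dir = tempfile.mkdtemp()
model = TinyModel()
tf.saved_model.save(
    model,
    export_dir,
    signatures={"serving_default": model.__call__.get_concrete_function()},
)

loaded = tf.saved_model.load(export_dir)
print(describe_signature(loaded.signatures["serving_default"]))
```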

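After that, I stepped through the loading code with a debugger to understand how the SavedModel works internally, and found where the model stores its data (weights, ...). Here you can see the output of model.signatures['serving_default'].variables: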
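The resource tensors printed earlier are the raw handles the signature function captures; its variables attribute exposes the corresponding tf.Variable objects, which carry a name, a shape and a value. A minimal sketch, assuming a hypothetical TinyModel in place of the Hub model; signature_weights is likewise a hypothetical helper, and with the real model you would call it on model.signatures['serving_default']:

```python
import tempfile

import tensorflow as tf


class TinyModel(tf.Module):
    """Hypothetical stand-in for the Hub model, saved and reloaded locally."""

    def __init__(self):
        super().__init__()
        self.kernel = tf.Variable(tf.ones([3, 2]), name="kernel")
        self.bias = tf.Variable(tf.zeros([2]), name="bias")

    @tf.function(input_signature=[tf.TensorSpec([None, 3], tf.float32)])
    def __call__(self, input_tensor):
        return {"scores": tf.matmul(input_tensor, self.kernel) + self.bias}


def signature_weights(fn):
    """List (name, shape, value) for every variable captured by a signature.

    A list rather than a dict, because variable names are not guaranteed unique.
    """
    return [(v.name, tuple(v.shape), v.numpy()) for v in fn.variables]


export_dir = tempfile.mkdtemp()
model = TinyModel()
tf.saved_model.save(
    model,
    export_dir,
    signatures={"serving_default": model.__call__.get_concrete_function()},
)

loaded = tf.saved_model.load(export_dir)
weights = signature_weights(loaded.signatures["serving_default"])
for name, shape, _ in weights:
    print(name, shape)
```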

Short summary of the answer: the variables of the layers can be accessed via model.signatures['serving_default'].variables.
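An alternative that does not go through the signature: a SavedModel stores its weights as an object-based checkpoint under variables/variables, which tf.train.list_variables and tf.train.load_variable can read directly. Each key is the attribute path of the variable followed by /.ATTRIBUTES/VARIABLE_VALUE, so grouping on the path prefix gives a per-layer view. A minimal sketch, assuming a hypothetical two-layer TinyModel saved locally; for the Hub model you would pass the cache directory instead, and its paths presumably start with _model/:

```python
import collections
import os
import tempfile

import tensorflow as tf


class Dense(tf.Module):
    """Hypothetical layer holding a kernel and a bias."""

    def __init__(self, n_in, n_out):
        super().__init__()
        self.kernel = tf.Variable(tf.ones([n_in, n_out]))
        self.bias = tf.Variable(tf.zeros([n_out]))


class TinyModel(tf.Module):
    """Hypothetical stand-in for the Hub model."""

    def __init__(self):
        super().__init__()
        self.layer1 = Dense(3, 4)
        self.layer2 = Dense(4, 2)


SUFFIX = "/.ATTRIBUTES/VARIABLE_VALUE"


def weights_by_layer(export_dir):
    """Read a SavedModel's checkpoint and group variable values by layer path."""
    prefix = os.path.join(export_dir, "variables", "variables")
    layers = collections.defaultdict(dict)
    for key, _shape in tf.train.list_variables(prefix):
        if not key.endswith(SUFFIX):
            continue  # skip the serialized object graph and other non-variable entries
        layer, _, var_name = key[: -len(SUFFIX)].rpartition("/")
        layers[layer][var_name] = tf.train.load_variable(prefix, key)
    return dict(layers)


export_dir = tempfile.mkdtemp()
tf.saved_model.save(TinyModel(), export_dir)

for layer, variables in weights_by_layer(export_dir).items():
    print(layer, {name: value.shape for name, value in variables.items()})
```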