Convert models (or only weights?) downloaded with the applications module to TFLite

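I am trying to convert a MobileNet model, downloaded with the applications module in tf.keras, to TensorFlow Lite format. I am using TensorFlow 1.13. I don't know whether the downloaded file stores only the weights, or weights + architecture + optimizer state. When I try the conversion command: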

from tensorflow import lite

lite.TFLiteConverter.from_keras_model_file( '/path/to/mobilenet_1_0_224_tf.h5' )

it fails with this error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/lite/python/lite.py", line 370, in from_keras_model_file
    keras_model = _keras.models.load_model(model_file)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/keras/engine/saving.py", line 232, in load_model
    raise ValueError('No model found in config file.')
ValueError: No model found in config file.
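The traceback comes from load_model, which raises this error when the HDF5 file has no model_config attribute at its root, i.e. when the file stores weights only. A minimal sketch for checking a file directly, assuming h5py is installed (Keras uses it for HDF5 saving); describe_h5 is a hypothetical helper name, and the keys it looks for follow the Keras HDF5 layout written by model.save() and save_weights():

```python
import h5py


def describe_h5(path):
    """Report what a Keras HDF5 file contains (hypothetical helper).

    Full-model files written by model.save() carry a 'model_config'
    attribute at the root; weights-only files written by save_weights()
    (such as the ImageNet weights fetched by keras.applications) do not,
    which is what makes load_model() fail with "No model found in config file."
    """
    with h5py.File(path, "r") as f:
        attrs = dict(f.attrs)
        return {
            # Architecture is present only in full-model files.
            "has_architecture": "model_config" in attrs,
            # Optimizer state is saved only for compiled models with include_optimizer=True.
            "has_optimizer_state": "optimizer_weights" in f,
            # Weights-only files list layer names at the root; full models nest them.
            "has_weights": "layer_names" in attrs or "model_weights" in f,
        }
```

Calling describe_h5('/path/to/mobilenet_1_0_224_tf.h5') on the downloaded file should show whether has_architecture is False, which would confirm the weights-only assumption.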

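From this, I assumed the file contains only the weights. So I tried loading the model with the applications module and saving it with model.save(). But that produced the following error: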

Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 300, in __init__
    fetch, allow_tensor=True, allow_operation=True))
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 3478, in as_graph_element
    return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 3557, in _as_graph_element_locked
    raise ValueError("Tensor %s is not an element of this graph." % obj)
ValueError: Tensor Tensor("conv1/kernel/Read/ReadVariableOp:0", shape=(3, 3, 3, 32), dtype=float32) is not an element of this graph.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/keras/engine/network.py", line 1334, in save
    save_model(self, filepath, overwrite, include_optimizer)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/keras/engine/saving.py", line 111, in save_model
    save_weights_to_hdf5_group(model_weights_group, model_layers)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/keras/engine/saving.py", line 742, in save_weights_to_hdf5_group
    weight_values = K.batch_get_value(symbolic_weights)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/keras/backend.py", line 2819, in batch_get_value
    return get_session().run(tensors)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 929, in run
    run_metadata_ptr)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1137, in _run
    self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 471, in __init__
    self._fetch_mapper = _FetchMapper.for_fetch(fetches)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 261, in for_fetch
    return _ListFetchMapper(fetch)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 370, in __init__
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 370, in <listcomp>
    self._mappers = [_FetchMapper.for_fetch(fetch) for fetch in fetches]
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 271, in for_fetch
    return _ElementFetchMapper(fetches, contraction_fn)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 307, in __init__
    'Tensor. (%s)' % (fetch, str(e)))
ValueError: Fetch argument <tf.Variable 'conv1/kernel:0' shape=(3, 3, 3, 32) dtype=float32> cannot be interpreted as a Tensor. (Tensor Tensor("conv1/kernel/Read/ReadVariableOp:0", shape=(3, 3, 3, 32), dtype=float32) is not an element of this graph.)

Does anyone know what the actual problem is here? Thanks in advance.

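How did you save your model? Perhaps you saved only the weights rather than the full model, and you are now calling load_model on a file that contains no model.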

If that is not the problem, try clearing the session:

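# Use the tf.keras backend (not standalone keras) so the session being cleared is the one tf.keras uses
from tensorflow.keras.backend import clear_session
clear_session()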
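If the .h5 file turns out to be weights-only, one possible workaround is to rebuild the architecture in tf.keras, load the weights into it, and save a full model file that from_keras_model_file can read. A minimal sketch, assuming TensorFlow 1.x with tf.keras used throughout (mixing standalone keras and tf.keras, or building the model in a different graph from the one the session runs, is one plausible cause of the "is not an element of this graph" error); rebuild_full_model and the output path are hypothetical:

```python
import tensorflow as tf


def rebuild_full_model(weights_path, out_path, alpha=1.0, input_shape=(224, 224, 3)):
    """Rebuild MobileNet, load a weights-only file, and save architecture + weights.

    Hypothetical helper; alpha and input_shape must match the weights file
    (mobilenet_1_0_224_tf.h5 corresponds to alpha=1.0 and 224x224 inputs).
    """
    # Start from a fresh graph/session so every variable lives in the same graph.
    tf.keras.backend.clear_session()
    # weights=None builds the architecture without downloading anything.
    model = tf.keras.applications.MobileNet(
        input_shape=input_shape, alpha=alpha, weights=None)
    # Load the weights-only file fetched earlier by keras.applications.
    model.load_weights(weights_path)
    # model.save() also writes the architecture, so load_model() can read it back.
    model.save(out_path)
    return out_path
```

For example, rebuild_full_model('/path/to/mobilenet_1_0_224_tf.h5', 'mobilenet_full.h5') would produce a file that can then be passed to TFLiteConverter.from_keras_model_file. In TensorFlow 2.x the converter can instead take the in-memory model via tf.lite.TFLiteConverter.from_keras_model(model).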

This is how I convert the model:

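import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_keras_model_file('model name')  # path to a full-model .h5 file
tflite_model = converter.convert()
with open("converted/model.tflite", "wb") as f:  # closes the file after writing
    f.write(tflite_model)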
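To sanity-check the result, the converted file can be loaded with the TFLite interpreter and run once on a dummy input. A minimal sketch, assuming TensorFlow 1.13 or later (where tf.lite.Interpreter is available) and a model with a single input and a single output; run_tflite_once is a hypothetical name:

```python
import numpy as np
import tensorflow as tf


def run_tflite_once(model_path):
    """Load a .tflite file, run one inference on zeros, and return the output."""
    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    # Assumes exactly one input and one output tensor.
    input_details = interpreter.get_input_details()[0]
    output_details = interpreter.get_output_details()[0]
    # Build a dummy batch matching the model's expected shape and dtype.
    dummy = np.zeros(input_details["shape"], dtype=input_details["dtype"])
    interpreter.set_tensor(input_details["index"], dummy)
    interpreter.invoke()
    return interpreter.get_tensor(output_details["index"])
```

For the MobileNet classifier, run_tflite_once("converted/model.tflite") should return an array of shape (1, 1000), one score per ImageNet class.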