Convert .pb file to .tflite
I trained a model on a custom dataset using the TensorFlow Object Detection API, with TensorFlow version 2.2.0. I tried to convert the .pb file to .tflite with the following code and got the error below:
import tensorflow as tf
# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model('/content/drive/MyDrive/FINAL DNET MODEL/inference_graph/saved_model') # path to the SavedModel directory
tflite_model = converter.convert()
# Save the model.
with open('model.tflite', 'wb') as f:
    f.write(tflite_model)
Error message:
---------------------------------------------------------------------------
Exception                                 Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/tensorflow/lite/python/convert.py in toco_convert_protos(model_flags_str, toco_flags_str, input_data_str, debug_info_str, enable_mlir_converter)
    212                                       debug_info_str,
--> 213                                       enable_mlir_converter)
    214     return model_str

4 frames

Exception: <unknown>:0: error: loc(callsite(callsite("map/TensorArrayV2_1@__inference_call_func_18902" at "StatefulPartitionedCall@__inference_signature_wrapper_23056") at "StatefulPartitionedCall")): requires element_shape to be 1D tensor during TF Lite transformation pass
<unknown>:0: note: loc("StatefulPartitionedCall"): called from
<unknown>:0: error: loc(callsite(callsite("map/TensorArrayV2_1@__inference_call_func_18902" at "StatefulPartitionedCall@__inference_signature_wrapper_23056") at "StatefulPartitionedCall")): failed to legalize operation 'tf.TensorListReserve' that was explicitly marked illegal
<unknown>:0: note: loc("StatefulPartitionedCall"): called from
During handling of the above exception, another exception occurred:
ConverterError                            Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/tensorflow/lite/python/convert.py in toco_convert_protos(model_flags_str, toco_flags_str, input_data_str, debug_info_str, enable_mlir_converter)
    214     return model_str
    215   except Exception as e:
--> 216     raise ConverterError(str(e))
    217
    218 if distutils.spawn.find_executable(_toco_from_proto_bin) is None:

ConverterError: <unknown>:0: error: loc(callsite(callsite("map/TensorArrayV2_1@__inference_call_func_18902" at "StatefulPartitionedCall@__inference_signature_wrapper_23056") at "StatefulPartitionedCall")): requires element_shape to be 1D tensor during TF Lite transformation pass
<unknown>:0: note: loc("StatefulPartitionedCall"): called from
<unknown>:0: error: loc(callsite(callsite("map/TensorArrayV2_1@__inference_call_func_18902" at "StatefulPartitionedCall@__inference_signature_wrapper_23056") at "StatefulPartitionedCall")): failed to legalize operation 'tf.TensorListReserve' that was explicitly marked illegal
<unknown>:0: note: loc("StatefulPartitionedCall"): called from
Please help me resolve this error and convert the .pb file to .tflite.
This GitHub issue had a similar error message. This answer suggests specifying fn_output_signature=tf.TensorSpec(shape, dtype) when calling the map function.
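For reference, a minimal sketch of that suggestion (the element shape and the per-element function here are illustrative assumptions, not taken from the model above; fn_output_signature is available in more recent TensorFlow releases):

import tensorflow as tf

# Hypothetical example: giving tf.map_fn an explicit per-element output
# signature so the converter sees a static element_shape.
boxes = tf.random.uniform((8, 4))  # illustrative input batch

def scale_box(box):
    return box * 2.0  # illustrative per-element computation

scaled = tf.map_fn(
    scale_box,
    boxes,
    fn_output_signature=tf.TensorSpec(shape=(4,), dtype=tf.float32),
)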
Try running something like this:
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model('PATH2model')  # path to the SavedModel directory
# Let ops without a TFLite builtin equivalent fall back to TensorFlow (Flex) ops.
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS
]
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
with open("model.tflite", "wb") as f:
    f.write(tflite_model)