TFLite model conversion shows warning and using the converted model with interpreter crashes python kernel

I have a tensorflow keras model with the following architecture:

Model: "model_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_3 (InputLayer)        [(None, 199)]             0         
                                                                 
 token_and_position_embeddin  (None, 199, 256)         2611456   
 g_2 (TokenAndPositionEmbedd                                     
 ing)                                                            
                                                                 
 lstm_4 (LSTM)               (None, 199, 150)          244200    
                                                                 
 lstm_5 (LSTM)               (None, 150)               180600    
                                                                 
 dense_2 (Dense)             (None, 10001)             1510151   
                                                                 
=================================================================
Total params: 4,546,407
Trainable params: 4,546,407
Non-trainable params: 0
_________________________________________________________________
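
For context, the model is roughly equivalent to the sketch below. TokenAndPositionEmbedding is assumed here to be the usual custom layer that sums a token embedding with a learned position embedding (vocabulary of 10001 tokens, sequence length 199); the exact implementation does not matter for the conversion problem:

import tensorflow as tf
from tensorflow.keras import layers

VOCAB_SIZE = 10001
MAXLEN = 199
EMBED_DIM = 256

# Assumed implementation: token embedding plus a learned position embedding.
class TokenAndPositionEmbedding(layers.Layer):
    def __init__(self, maxlen, vocab_size, embed_dim):
        super().__init__()
        self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)

    def call(self, x):
        positions = tf.range(start=0, limit=tf.shape(x)[-1], delta=1)
        return self.token_emb(x) + self.pos_emb(positions)

inputs = layers.Input(shape=(MAXLEN,))
x = TokenAndPositionEmbedding(MAXLEN, VOCAB_SIZE, EMBED_DIM)(inputs)
x = layers.LSTM(150, return_sequences=True)(x)
x = layers.LSTM(150)(x)
outputs = layers.Dense(VOCAB_SIZE, activation="softmax")(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)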


I am trying to convert it to tflite this way:

converter = tf.lite.TFLiteConverter.from_saved_model(str(model_saved_dir))
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
converter.target_spec.supported_types = [tf.float16]
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()
# Save the model.
with open(model_home / 'model.tflite', 'wb') as f:
    f.write(tflite_model)

During conversion it shows the following warnings:

2022-04-12 20:17:23.584937: I tensorflow/cc/saved_model/loader.cc:301] SavedModel load for tags { serve }; Status: success: OK. Took 329307 microseconds.
2022-04-12 20:17:23.771051: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:237] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2022-04-12 20:17:23.983377: W tensorflow/compiler/mlir/lite/flatbuffer_export.cc:1892] TFLite interpreter needs to link Flex delegate in order to run the model since it contains the following Select TFop(s):
Flex ops: FlexTensorListFromTensor, FlexTensorListGetItem, FlexTensorListReserve, FlexTensorListSetItem, FlexTensorListStack
Details:
    tf.TensorListFromTensor(tensor<199x?x256xf32>, tensor<2xi32>) -> (tensor<!tf_type.variant<tensor<?x256xf32>>>) : {device = ""}
    tf.TensorListFromTensor(tensor<?x?x150xf32>, tensor<2xi32>) -> (tensor<!tf_type.variant<tensor<?x150xf32>>>) : {device = ""}
    tf.TensorListGetItem(tensor<!tf_type.variant<tensor<?x150xf32>>>, tensor<i32>, tensor<2xi32>) -> (tensor<?x150xf32>) : {device = ""}
    tf.TensorListGetItem(tensor<!tf_type.variant<tensor<?x256xf32>>>, tensor<i32>, tensor<2xi32>) -> (tensor<?x256xf32>) : {device = ""}
    tf.TensorListReserve(tensor<2xi32>, tensor<i32>) -> (tensor<!tf_type.variant<tensor<?x150xf32>>>) : {device = ""}
    tf.TensorListSetItem(tensor<!tf_type.variant<tensor<?x150xf32>>>, tensor<i32>, tensor<?x150xf32>) -> (tensor<!tf_type.variant<tensor<?x150xf32>>>) : {device = ""}
    tf.TensorListStack(tensor<!tf_type.variant<tensor<?x150xf32>>>, tensor<2xi32>) -> (tensor<?x?x150xf32>) : {device = "", num_elements = -1 : i64}
See instructions: https://www.tensorflow.org/lite/guide/ops_select
WARNING:absl:Buffer deduplication procedure will be skipped when flatbuffer library is not properly loaded
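
This warning is expected when SELECT_TF_OPS is enabled: the dynamic sequence handling inside the LSTM layers is lowered to TensorList ops, which only run through the Flex delegate (bundled with the regular tensorflow pip package in Python, but shipped as a separate dependency on Android, as described in the linked ops_select guide). To confirm which ops actually ended up in the converted flatbuffer, tf.lite.experimental.Analyzer (available in recent TF releases) can list them; a minimal sketch:

import tensorflow as tf

# Print a per-op report for the converted model; any Select/Flex ops
# appear in the listing alongside the TFLite builtins.
tf.lite.experimental.Analyzer.analyze(model_path=str(model_home / 'model.tflite'))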

The .tflite file is still created despite the warnings, but when I try to use the model like this:

interpreter = tf.lite.Interpreter(model_path=str(model_home / 'model.tflite'))
interpreter.allocate_tensors()
input_index = interpreter.get_input_details()[0]["index"]
output_index = interpreter.get_output_details()[0]["index"]

# dummy input
ip = [[0 for _ in range(198)] + [1]]
ip = np.array(ip, dtype=np.int32)

interpreter.set_tensor(input_index, ip)
interpreter.invoke()
predictions = interpreter.get_tensor(output_index)
print(predictions)

my jupyter notebook kernel dies and the python process crashes. My tensorflow version is "2.8.0" and my python version is "3.8.10". When I try to use the model from Android Studio, it also crashes, with the error ByteBuffer is not a valid FlatBuffer model.
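
Side note for anyone debugging something similar: before applying the fix below, it is worth printing the interpreter's input details; shape_signature reports dynamic dimensions as -1, which at least rules out a plain shape or dtype mismatch as the cause of the crash. A quick check:

input_details = interpreter.get_input_details()[0]
# "shape" is the currently allocated shape; "shape_signature" keeps -1
# for dimensions that are dynamic in the model.
print(input_details["shape"], input_details["shape_signature"], input_details["dtype"])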

I solved the python kernel crash, as well as the crash when loading the model in Android Studio, by adding the input and output sizes to the model signature before converting it:

run_model = tf.function(lambda x: model(x))
# This is important, let's fix the input size.
BATCH_SIZE = 1
INPUT_SIZE = 199
concrete_func = run_model.get_concrete_function(
    tf.TensorSpec([BATCH_SIZE, INPUT_SIZE], model.inputs[0].dtype))

# Save the model with the fixed-shape signature, then convert from the
# saved model so the converter actually sees the static input size.
model.save(str(model_saved_dir), save_format="tf", signatures=concrete_func)

converter = tf.lite.TFLiteConverter.from_saved_model(str(model_saved_dir))
tflite_model = converter.convert()
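
With the static signature in place, the conversion should no longer need the TensorList Flex ops, and the result can be sanity-checked directly from the in-memory bytes before writing it to disk:

interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
# Expect a fully static input shape of [1, 199] now.
print(interpreter.get_input_details()[0]["shape"])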