如何从量化的 TFLite 中获取 class 索引?
How to get class indices from a quantized TFLite?
我一直在使用 TensorFlow 训练量化的 Mobilenet V2,但我不知道如何从中获取 class 索引。
我正在使用 Tensorflow 1.12
以下是我的输入和输出详细信息。
Input details [{'name': 'normalized_input_image_tensor', 'index': 260, 'shape': array([ 1, 300, 300, 3], dtype=int32), 'shape_signature': array([ 1, 300, 300, 3], dtype=int32), 'dtype': <class 'numpy.uint8'>, 'quantization': (0.0078125, 128), 'quantization_parameters': {'scales': array([0.0078125], dtype=float32), 'zero_points': array([128], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]
Output details [{'name': 'TFLite_Detection_PostProcess', 'index': 252, 'shape': array([ 1, 10, 4], dtype=int32), 'shape_signature': array([ 1, 10, 4], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'TFLite_Detection_PostProcess:1', 'index': 253, 'shape': array([ 1, 10], dtype=int32), 'shape_signature': array([ 1, 10], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'TFLite_Detection_PostProcess:2', 'index': 254, 'shape': array([ 1, 10], dtype=int32), 'shape_signature': array([ 1, 10], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'TFLite_Detection_PostProcess:3', 'index': 255, 'shape': array([1], dtype=int32), 'shape_signature': array([1], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]
我一直在尝试通过执行以下操作来获取 class 索引:
interpreter = tf.lite.Interpreter(model_path=PATH)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
classes = interpreter.get_tensor(output_details[1]['index'])[0]
但是,classes 索引不正确。打印时,classes
看起来像这样:[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
。我的数据集中有超过 1 个 class,所以这没有任何意义。
获取 class 索引的正确方法是什么?
尝试使用:
classes = interpreter.get_tensor(output_details[0]['index'])
经过大量实验,发现不是量化问题。我们在创建 .tflite 时使用了错误的 graph_def
.pb 文件,因此它预测 类 不存在。
我一直在使用 TensorFlow 训练量化的 Mobilenet V2,但我不知道如何从中获取 class 索引。
我正在使用 Tensorflow 1.12
以下是我的输入和输出详细信息。
Input details [{'name': 'normalized_input_image_tensor', 'index': 260, 'shape': array([ 1, 300, 300, 3], dtype=int32), 'shape_signature': array([ 1, 300, 300, 3], dtype=int32), 'dtype': <class 'numpy.uint8'>, 'quantization': (0.0078125, 128), 'quantization_parameters': {'scales': array([0.0078125], dtype=float32), 'zero_points': array([128], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]
Output details [{'name': 'TFLite_Detection_PostProcess', 'index': 252, 'shape': array([ 1, 10, 4], dtype=int32), 'shape_signature': array([ 1, 10, 4], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'TFLite_Detection_PostProcess:1', 'index': 253, 'shape': array([ 1, 10], dtype=int32), 'shape_signature': array([ 1, 10], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'TFLite_Detection_PostProcess:2', 'index': 254, 'shape': array([ 1, 10], dtype=int32), 'shape_signature': array([ 1, 10], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'TFLite_Detection_PostProcess:3', 'index': 255, 'shape': array([1], dtype=int32), 'shape_signature': array([1], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]
我一直在尝试通过执行以下操作来获取 class 索引:
interpreter = tf.lite.Interpreter(model_path=PATH)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
classes = interpreter.get_tensor(output_details[1]['index'])[0]
但是,classes 索引不正确。打印时,classes
看起来像这样:[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
。我的数据集中有超过 1 个 class,所以这没有任何意义。
获取 class 索引的正确方法是什么?
尝试使用:
classes = interpreter.get_tensor(output_details[0]['index'])
经过大量实验,发现不是量化问题。我们在创建 .tflite 时使用了错误的 graph_def
.pb 文件,因此它预测 类 不存在。