What should be the input and output form of the variables for the float object detection model?

https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_ssd_tflite_v1.zip

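I am making an Android object detection app with GPU delegate support. The link above is for the TensorFlow Lite object detection float model. There is no documentation available for it. I want to know the input and output form of the variables for this tflite model so that I can feed them to the interpreter for the GPU delegate. Thanks in advance!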

I use Google Colab, so I used the code below to find out the inputs and outputs:

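import tensorflow as tf
# Load the model file extracted from mobilenet_ssd_tflite_v1.zip.
interpreter = tf.lite.Interpreter(model_path='mobilenet_ssd.tflite')
# Print the name, index, shape and dtype of every input and output tensor.
print(interpreter.get_input_details())
print(interpreter.get_output_details())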

So: unzip the archive, find the .tflite file, and load it with the code above. I did that, and the result was:

[{'name': 'Preprocessor/sub', 'index': 165, 'shape': array([ 1, 300, 300, 3], dtype=int32), 'shape_signature': array([ 1, 300, 300, 3], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]

[{'name': 'concat', 'index': 172, 'shape': array([ 1, 1917, 4], dtype=int32), 'shape_signature': array([ 1, 1917, 4], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'concat_1', 'index': 173, 'shape': array([ 1, 1917, 91], dtype=int32), 'shape_signature': array([ 1, 1917, 91], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]
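Read together, these details say the input is a single 300 x 300 RGB image in NHWC layout as float32, and the outputs appear to be raw SSD head outputs with no post-processing op: concat holds four box-regression values for each of 1917 anchors, and concat_1 holds 91 class scores per anchor (presumably the 90 COCO classes plus a background class at index 0). Turning concat into boxes needs the anchor (prior) boxes, which are not part of the tensor list above. The sketch below assumes the TensorFlow Object Detection API's usual SSD encoding, [ty, tx, th, tw] with scale factors 10, 10, 5, 5, and anchors given as (ycenter, xcenter, height, width) in normalized coordinates; both are assumptions, not something the model file confirms. The helper names decode_box and decode_boxes are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

# Assumed box-coder scale factors (TF Object Detection API SSD defaults): y, x, h, w.
Y_SCALE, X_SCALE, H_SCALE, W_SCALE = 10.0, 10.0, 5.0, 5.0

Box = Tuple[float, float, float, float]  # (ymin, xmin, ymax, xmax), normalized


def decode_box(encoding: Sequence[float], anchor: Sequence[float]) -> Box:
    """Decode one row of concat ([ty, tx, th, tw]) against one anchor.

    The anchor is assumed to be (ycenter, xcenter, height, width) in [0, 1].
    """
    ty, tx, th, tw = encoding
    ya, xa, ha, wa = anchor
    # Offsets move the anchor centre; the log-scale terms resize it.
    ycenter = ty / Y_SCALE * ha + ya
    xcenter = tx / X_SCALE * wa + xa
    h = math.exp(th / H_SCALE) * ha
    w = math.exp(tw / W_SCALE) * wa
    return (ycenter - h / 2, xcenter - w / 2, ycenter + h / 2, xcenter + w / 2)


def decode_boxes(encodings: Sequence[Sequence[float]],
                 anchors: Sequence[Sequence[float]]) -> List[Box]:
    """Decode all rows of concat (batch dimension already dropped, so 1917 rows)."""
    if len(encodings) != len(anchors):
        raise ValueError("need exactly one anchor per box encoding")
    return [decode_box(e, a) for e, a in zip(encodings, anchors)]
```

As a sanity check, an all-zero encoding returns the anchor's own corners: decode_box([0, 0, 0, 0], [0.5, 0.5, 0.2, 0.4]) gives (0.4, 0.3, 0.6, 0.7) up to float rounding. Multiply the normalized corners by the original image height and width to get pixel coordinates.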

You can also do this inside Android:

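// Imports used by the two functions below (place them at the top of the file).
// The functions also rely on class members not shown here: interpreter,
// inputImageWidth, inputImageHeight, modelInputSize, outputShape, isInitialized,
// TAG, FLOAT_TYPE_SIZE (= 4) and PIXEL_SIZE (= 3).
import android.app.Application
import android.content.res.AssetManager
import android.util.Log
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import org.tensorflow.lite.DataType
import org.tensorflow.lite.Interpreter
import java.io.FileInputStream
import java.io.IOException
import java.nio.MappedByteBuffer
import java.nio.channels.FileChannel
import java.util.Arrays

// Initialize interpreter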
@Throws(IOException::class)
private suspend fun initializeInterpreter(app: Application) = withContext(Dispatchers.IO) {
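    // Load the TF Lite model from the asset folder and initialize the TF Lite Interpreter with NNAPI disabled.
    val assetManager = app.assets
    val model = loadModelFile(assetManager, "mobilenet_ssd.tflite")
    val options = Interpreter.Options()
    options.setUseNNAPI(false)
    // For GPU delegate support, add options.addDelegate(GpuDelegate()) here
    // (org.tensorflow.lite.gpu.GpuDelegate, from the tensorflow-lite-gpu dependency).
    interpreter = Interpreter(model, options)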
    // Reads type and shape of input and output tensors, respectively.
    val imageTensorIndex = 0
    val inputShape: IntArray =
        interpreter.getInputTensor(imageTensorIndex).shape() // {1, 300, 300, 3} (NHWC)
    Log.e("INPUT_TENSOR_WHOLE", Arrays.toString(inputShape))
    val imageDataType: DataType =
        interpreter.getInputTensor(imageTensorIndex).dataType()
    Log.e("INPUT_DATA_TYPE", imageDataType.toString())

    // modelInputSize is the number of bytes to allocate for the model input.
    // FLOAT_TYPE_SIZE is the number of bytes per input value: float32, so 4 bytes.
    // PIXEL_SIZE is the number of color channels per pixel: RGB, so 3 channels.
    // The input shape is NHWC, so index 1 is the height and index 2 is the width.
    inputImageHeight = inputShape[1]
    inputImageWidth = inputShape[2]
    modelInputSize = FLOAT_TYPE_SIZE * inputImageWidth *
            inputImageHeight * PIXEL_SIZE

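    // Output 0 ("concat") holds the box encodings: {1, 1917, 4}.
    val boxesTensorIndex = 0
    outputShape =
        interpreter.getOutputTensor(boxesTensorIndex).shape()
    Log.e("OUTPUT_TENSOR_SHAPE", outputShape.contentToString())
    val boxesDataType: DataType =
        interpreter.getOutputTensor(boxesTensorIndex).dataType()
    Log.e("OUTPUT_DATA_TYPE", boxesDataType.toString())
    // Output 1 ("concat_1") holds the class scores: {1, 1917, 91}.
    val scoresShape: IntArray = interpreter.getOutputTensor(1).shape()
    Log.e("SCORES_TENSOR_SHAPE", scoresShape.contentToString())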
    isInitialized = true
    Log.e(TAG, "Initialized TFLite interpreter.")

    // Tensor indices can also be looked up by name instead of hard-coding 0:
    /*val inputTensorModel: Int = interpreter.getInputIndex("Preprocessor/sub")
    Log.e("INPUT_TENSOR", inputTensorModel.toString())*/

}
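The function above sizes the input buffer as modelInputSize bytes. Filling it means writing float32 values in the same NHWC order (row by row, R, G, B per pixel). Because the input tensor is named Preprocessor/sub, the model's own normalization step appears to have been cut off, so the app most likely has to normalize the pixels itself; the sketch assumes the common MobileNet-SSD convention of mapping [0, 255] to [-1, 1] with (p - 127.5) / 127.5, which you should verify against your results. It also computes the byte sizes of all three tensors from the shapes printed above. The helper names tensor_bytes and preprocess are hypothetical.

```python
import struct
from math import prod
from typing import Sequence, Tuple

FLOAT_TYPE_SIZE = 4  # bytes per float32 value
PIXEL_SIZE = 3       # RGB channels

# Shapes taken from get_input_details() / get_output_details() above.
INPUT_SHAPE = (1, 300, 300, 3)   # Preprocessor/sub
BOXES_SHAPE = (1, 1917, 4)       # concat
SCORES_SHAPE = (1, 1917, 91)     # concat_1


def tensor_bytes(shape: Sequence[int], bytes_per_value: int = FLOAT_TYPE_SIZE) -> int:
    """Number of bytes needed to hold a float32 tensor of this shape."""
    return prod(shape) * bytes_per_value


def preprocess(pixels: Sequence[Tuple[int, int, int]], width: int, height: int) -> bytes:
    """Pack row-major RGB pixels (0-255) into a little-endian float32 NHWC buffer.

    Assumes the model expects values in [-1, 1], i.e. (p - 127.5) / 127.5.
    """
    if len(pixels) != width * height:
        raise ValueError("expected width * height pixels")
    # Channel values stay interleaved per pixel (R, G, B), matching NHWC.
    values = [(c - 127.5) / 127.5 for rgb in pixels for c in rgb]
    return struct.pack(f"<{len(values)}f", *values)
```

With these shapes, tensor_bytes(INPUT_SHAPE) is 1,080,000 bytes (4 x 300 x 300 x 3), which matches modelInputSize in the Kotlin code, and the two outputs need 30,672 and 697,788 bytes. On Android the equivalent buffers are usually created with ByteBuffer.allocateDirect(size).order(ByteOrder.nativeOrder()), which is little-endian on common Android devices.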

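// Memory-maps a model file from the assets folder. The .tflite asset must be stored
// uncompressed for openFd() to work (for example via aaptOptions { noCompress "tflite" }).
@Throws(IOException::class)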
private fun loadModelFile(assetManager: AssetManager, filename: String): MappedByteBuffer {
    val fileDescriptor = assetManager.openFd(filename)
    val inputStream = FileInputStream(fileDescriptor.fileDescriptor)
    val fileChannel = inputStream.channel
    val startOffset = fileDescriptor.startOffset
    val declaredLength = fileDescriptor.declaredLength
    return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
}
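Once interpreter.runForMultipleInputsOutputs(...) has filled both output buffers, the detections still have to be picked out by hand, since this model does not seem to contain a post-processing op. The sketch below assumes that concat_1 holds per-class logits that become probabilities through a sigmoid (the usual SSD setting in the TensorFlow Object Detection API) and that index 0 is background. It keeps the best non-background class per anchor above a score threshold and then applies greedy per-class non-maximum suppression. The threshold defaults are illustrative only, and select_detections, iou and sigmoid are hypothetical helper names.

```python
import math
from typing import List, Sequence, Tuple

Box = Tuple[float, float, float, float]  # (ymin, xmin, ymax, xmax)


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def iou(a: Box, b: Box) -> float:
    """Intersection over union of two (ymin, xmin, ymax, xmax) boxes."""
    inter_h = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_w = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_h * inter_w
    union = ((a[2] - a[0]) * (a[3] - a[1]) +
             (b[2] - b[0]) * (b[3] - b[1]) - inter)
    return inter / union if union > 0 else 0.0


def select_detections(boxes: Sequence[Box],
                      logits: Sequence[Sequence[float]],
                      score_threshold: float = 0.5,
                      iou_threshold: float = 0.6) -> List[Tuple[Box, int, float]]:
    """Return (box, class_index, score) tuples after thresholding and greedy NMS.

    Class index 0 is assumed to be background and is skipped.
    """
    candidates = []
    for box, row in zip(boxes, logits):
        # Best non-background class for this anchor.
        cls = max(range(1, len(row)), key=lambda i: row[i])
        score = sigmoid(row[cls])
        if score >= score_threshold:
            candidates.append((box, cls, score))
    candidates.sort(key=lambda d: d[2], reverse=True)

    kept: List[Tuple[Box, int, float]] = []
    for det in candidates:
        # Drop a box if it overlaps an already-kept box of the same class too much.
        if all(k[1] != det[1] or iou(k[0], det[0]) < iou_threshold for k in kept):
            kept.append(det)
    return kept
```

Here boxes would be the decoded concat rows and logits the 1917 rows of concat_1; the class index can then be mapped to a COCO label list.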

Tag me if you need any help.