Image produced is incomplete - Cannot copy to a TensorFlowLite tensor (input_1) with bytes

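I am trying to load a tflite model and run it on an image.

My tflite model has the dimensions you can see in the image.

Right now, I get:

Cannot copy to a TensorFlowLite tensor (input_1) with 49152 bytes from a Java Buffer with 175584 bytes.

I don't understand how to handle the input and output tensor sizes. Right now I initialize with the input image size, and the output image size will be input * 4.

At what point do I have to "add" the 1 * 64 * 64 * 3 dimensions, given that I need to handle input images of any size?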
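Both byte counts in the error can be reproduced by hand. This assumes the input tensor is FLOAT32 (4 bytes per value) in NHWC layout (`[batch, height, width, channels]`). Under that assumption, the model declares 1 × 64 × 64 × 3 × 4 = 49152 bytes, while the unresized 118 × 124 RGB photo converted to floats takes 118 × 124 × 3 × 4 = 175584 bytes. Here is a minimal sketch of that arithmetic. `tensorBytes` is a hypothetical helper, not a TensorFlow Lite API:

```kotlin
// Hypothetical helper: number of bytes in a dense tensor of the given shape.
fun tensorBytes(shape: IntArray, bytesPerElement: Int): Int =
    shape.fold(1) { acc, dim -> acc * dim } * bytesPerElement

fun main() {
    val float32 = 4                               // bytes per FLOAT32 element
    val modelInput = intArrayOf(1, 64, 64, 3)     // shape the model declares
    val photoInput = intArrayOf(1, 118, 124, 3)   // the photo before resizing
    println(tensorBytes(modelInput, float32))     // 49152
    println(tensorBytes(photoInput, float32))     // 175584
}
```

The ratio 175584 / 49152 equals the ratio of the pixel counts (14632 / 4096). So the data type already matches and only the spatial size differs. Either the image has to be resized to 64 × 64, or the interpreter's input has to be resized to the image.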

try {
    tflitemodel = loadModelFile()
    tflite = Interpreter(tflitemodel, options)
} catch (e: IOException) {
    Log.e(TAG, "Fail to load model", e)
}

val imageTensorIndex = 0
val imageShape: IntArray =
    tflite.getInputTensor(imageTensorIndex).shape()
val imageDataType: DataType = tflite.getInputTensor(imageTensorIndex).dataType()
// Build a TensorImage object
var inputImageBuffer = TensorImage(imageDataType);

// Load the Bitmap
inputImageBuffer.load(bitmap)

// Preprocess image
val imgprocessor = ImageProcessor.Builder()
    .add(ResizeOp(inputImageBuffer.height,
        inputImageBuffer.width,
        ResizeOp.ResizeMethod.NEAREST_NEIGHBOR))
    //.add(NormalizeOp(127.5f, 127.5f))
    //.add(QuantizeOp(128.0f, 1 / 128.0f))
    .build()

// Process the image
val processedImage = imgprocessor.process(inputImageBuffer)

// Access the buffer ( byte[] ) of the processedImage
val imageBuffer = processedImage.buffer
val imageTensorBuffer = processedImage.tensorBuffer

// output result
val outputImageBuffer = TensorBuffer.createFixedSize(
    intArrayOf( inputImageBuffer.height * 4 ,
        inputImageBuffer.width * 4 ) ,
    DataType.FLOAT32 )

// Normalize image
val tensorProcessor = TensorProcessor.Builder()
    // Normalize the tensor given the mean and the standard deviation
    .add( NormalizeOp( 127.5f, 127.5f ) )
    .add( CastOp( DataType.FLOAT32 ) )
    .build()
val processedOutputTensor = tensorProcessor.process(outputImageBuffer)

tflite.run(imageTensorBuffer.buffer, processedOutputTensor.buffer)

I tried casting the output tensor to FLOAT32 or UINT8.
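A cast changes only the bytes per element, not the number of elements. Assuming the model's output is NHWC `[1, height * 4, width * 4, 3]` (the shape the answer below allocates), a 2-D `[height * 4, width * 4]` buffer has a third of the required elements whether it is FLOAT32 or UINT8. Similarly, running `NormalizeOp` over the output buffer before `tflite.run` has no lasting effect, because the interpreter overwrites that buffer. `NormalizeOp(mean, stddev)` computes `(x - mean) / stddev` element-wise, so turning model output back into pixel values means applying the inverse after inference. The sketch below mirrors that arithmetic on a plain `FloatArray`. `normalize` and `denormalize` are hypothetical helpers, and the mean 0 / stddev 255 pair is only an example:

```kotlin
// Same formula NormalizeOp(mean, stddev) applies element-wise: (x - mean) / stddev.
fun normalize(values: FloatArray, mean: Float, stddev: Float): FloatArray =
    FloatArray(values.size) { i -> (values[i] - mean) / stddev }

// Inverse mapping, meant for the model's output *after* inference.
fun denormalize(values: FloatArray, mean: Float, stddev: Float): FloatArray =
    FloatArray(values.size) { i -> values[i] * stddev + mean }

fun main() {
    val pixels = floatArrayOf(0f, 127.5f, 255f)
    val modelInput = normalize(pixels, 0f, 255f)
    // ... tflite.run(...) would happen here ...
    val restored = denormalize(modelInput, 0f, 255f)
    println(modelInput.joinToString())  // 0.0, 0.5, 1.0
    println(restored.joinToString())    // 0.0, 127.5, 255.0
}
```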

Update

I also tried this:

try {
    tflitemodel = loadModelFile()
    tflite = Interpreter(tflitemodel, options)
} catch (e: IOException) {
    Log.e(TAG, "Fail to load model", e)
}

val imageTensorIndex = 0
val imageDataType: DataType = tflite.getInputTensor(imageTensorIndex).dataType()

val imgprocessor = ImageProcessor.Builder()
    .add(ResizeOp(64, 64, ResizeOp.ResizeMethod.NEAREST_NEIGHBOR))
    .add( NormalizeOp( 0.0f, 255.0f ) )
    .add( CastOp( DataType.FLOAT32 ) )
    .build()

val inpIm = TensorImage(imageDataType)
inpIm.load(bitmap)

val processedImage = imgprocessor.process(inpIm)

val output = TensorBuffer.createFixedSize(
    intArrayOf(
        124 * 4,
        118 * 4,
        3,
        1
    ),
    DataType.FLOAT32
)

val tensorProcessor = TensorProcessor.Builder()
    .add( NormalizeOp( 0.0f, 255.0f ) )
    .add( CastOp( DataType.FLOAT32 ) )
    .build()

val processedOutputTensor = tensorProcessor.process(output)

tflite.run(processedImage.buffer, processedOutputTensor.buffer)

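This produces an incomplete image.

Note that the image I am currently using as input has dimensions 124 * 118 * 3.

The output image will have dimensions (124 * 4) * (118 * 4) * 3.

The model expects 64 * 64 * 3 as its input layer.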
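For a ×4 model, the output shape follows from the input shape the interpreter actually has. This sketch assumes NHWC FLOAT32 tensors and takes 124 as the width and 118 as the height, as the answer's `getOutputImage` does. Under those assumptions there are two consistent setups:

- Resize the photo to the declared 64 × 64 input and expect a 256 × 256 output.
- Resize the interpreter input to 118 × 124 and expect 472 × 496.

An output of `[124 * 4, 118 * 4, 3, 1]` allocated for a 64 × 64 input matches neither. `srOutputShape` is a hypothetical helper:

```kotlin
// Hypothetical helper: NHWC output shape of a super-resolution model with an
// integer upscale factor, given the NHWC input shape [batch, height, width, channels].
fun srOutputShape(input: IntArray, scale: Int): IntArray =
    intArrayOf(input[0], input[1] * scale, input[2] * scale, input[3])

fun main() {
    val scale = 4
    // Option 1: resize the photo to the model's declared 64 x 64 input.
    println(srOutputShape(intArrayOf(1, 64, 64, 3), scale).contentToString())   // [1, 256, 256, 3]
    // Option 2: resize the interpreter input to the photo (height 118, width 124).
    println(srOutputShape(intArrayOf(1, 118, 124, 3), scale).contentToString()) // [1, 472, 496, 3]
}
```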

I looked at your project. Your class should look like this:

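import android.app.Activity
import android.content.ContentValues
import android.content.Context
import android.content.Intent
import android.graphics.Bitmap
import android.graphics.BitmapFactory
import android.graphics.Color
import android.net.Uri
import android.os.Build
import android.os.Bundle
import android.os.Environment
import android.provider.MediaStore
import android.provider.OpenableColumns
import android.util.Log
import android.view.View
import android.widget.Button
import android.widget.ImageView
import android.widget.Switch
import androidx.activity.result.contract.ActivityResultContracts
import androidx.appcompat.app.AppCompatActivity
import org.tensorflow.lite.DataType
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.support.common.ops.CastOp
import org.tensorflow.lite.support.image.ImageProcessor
import org.tensorflow.lite.support.image.TensorImage
import org.tensorflow.lite.support.image.ops.ResizeOp
import java.io.File
import java.io.FileDescriptor
import java.io.FileInputStream
import java.io.FileOutputStream
import java.io.IOException
import java.io.InputStream
import java.io.OutputStream
import java.nio.ByteBuffer
import java.nio.MappedByteBuffer
import java.nio.channels.FileChannel

class MainActivity : AppCompatActivity() {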


    private val TAG = "SuperResolution"
    private val MODEL_NAME = "model_edsr.tflite"
    private val LR_IMAGE_HEIGHT = 24
    private val LR_IMAGE_WIDTH = 24
    private val UPSCALE_FACTOR = 4
    private val SR_IMAGE_HEIGHT = LR_IMAGE_HEIGHT * UPSCALE_FACTOR
    private val SR_IMAGE_WIDTH = LR_IMAGE_WIDTH * UPSCALE_FACTOR

    private lateinit var photoButton: Button
    private lateinit var srButton: Button
    private lateinit var colorizeButton: Button
    private var FILE_NAME = "photo.jpg"

    private lateinit var filename:String
    private var resultImg: Bitmap? = null

    private lateinit var gpuSwitch: Switch

    private lateinit var tflite: Interpreter
    private lateinit var tflitemodel: ByteBuffer

    private val INPUT_SIZE: Int = 96
    private val PIXEL_SIZE: Int = 3
    private val IMAGE_MEAN = 0
    private val IMAGE_STD = 255.0f


    private var bitmap: Bitmap? = null
    private var bitmapResult: Bitmap? = null

    /** A ByteBuffer to hold image data, to be fed into TensorFlow Lite as input/output  */
    private lateinit var imgDataInput: ByteBuffer
    private lateinit var imgDataOutput: ByteBuffer

    /** Dimensions of inputs.  */
    private val DIM_BATCH_SIZE = 1

    private val DIM_PIXEL_SIZE = 3

    private val DIM_IMG_SIZE_X = 64
    private val DIM_IMG_SIZE_Y = 64
    private lateinit var catBitmap: Bitmap
    /* Preallocated buffers for storing image data in. */
    private val intValues = IntArray(DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y)
    private lateinit var superImage: ImageView

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)
        superImage = findViewById(R.id.super_resolution_image)

        //val assetManager = assets
        catBitmap = getBitmapFromAsset("cat.png")


        srButton = findViewById(R.id.super_resolution)
        srButton.setOnClickListener { view: View ->

            val intent = Intent(this, SelectedImage::class.java)
            getImageResult.launch(intent)
        }


    }

    private fun getBitmapFromAsset(filePath: String?): Bitmap {
        val assetManager = assets
        val istr: InputStream
        var bitmap: Bitmap? = null
        try {
            istr = assetManager.open(filePath!!)
            bitmap = BitmapFactory.decodeStream(istr)
        } catch (e: IOException) {
            // handle exception
            Log.e("Bitmap_except", e.toString())

        }

        if (bitmap != null) {
            bitmap = Bitmap.createScaledBitmap(bitmap,64,64,true)
        }

        return bitmap?: Bitmap.createBitmap(10, 10, Bitmap.Config.ARGB_8888)
    }

    private val getImageResult =
        registerForActivityResult(ActivityResultContracts.StartActivityForResult()) { result ->
            if (result.resultCode == Activity.RESULT_OK) {
                var theImageUri: Uri? = null
                theImageUri = result.data?.getParcelableExtra<Uri>("imageuri")

                filename = "SR_" + theImageUri?.getOriginalFileName(this).toString()

                bitmap = uriToBitmap(theImageUri!!)!!//catBitmap//
                Log.v("width", bitmap!!.width.toString())

                if (bitmap != null) {
                    // call DL
                    val options = Interpreter.Options()
                    options.setNumThreads(5)
                    options.setUseNNAPI(true)
                    try {
                        tflitemodel = loadModelFile()
                        tflite = Interpreter(tflitemodel, options)
                        val index = tflite.getInputIndex("input_1")
                        // The input tensor is NHWC: [batch, height, width, channels]
                        tflite.resizeInput(
                            index,
                            intArrayOf(1, bitmap!!.height, bitmap!!.width, 3)
                        )
                    } catch (e: IOException) {
                        Log.e(TAG, "Fail to load model", e)
                    }

                    val imgprocessor = ImageProcessor.Builder()
                        .add(
                            // ResizeOp takes the target height first, then the target width
                            ResizeOp(bitmap!!.height,
                                bitmap!!.width,
                                ResizeOp.ResizeMethod.NEAREST_NEIGHBOR)
                        )
                        .add( CastOp( DataType.FLOAT32 ) )
                        .build()

                    val inpIm = TensorImage(DataType.FLOAT32)
                    inpIm.load(bitmap)

                    // Process the image
                    val processedImage = imgprocessor.process(inpIm)

                    // The output is NHWC as well: [1, 4 * height, 4 * width, 3]
                    val output2 = Array(1) { Array(4*bitmap!!.height) { Array(4*bitmap!!.width) { FloatArray(3) } } }

                    tflite.run(processedImage.buffer, output2)

                    // convertArrayToBitmap(imageArray, imageWidth, imageHeight)
                    bitmapResult = convertArrayToBitmap(output2, 4*bitmap!!.width, 4*bitmap!!.height)

                    Log.v("widthHR", bitmapResult!!.width.toString())
                    superImage.setImageBitmap(bitmapResult)

                }
            }
        }


    @Throws(IOException::class)
    private fun loadModelFile(): MappedByteBuffer {
        val fileDescriptor = assets.openFd(MODEL_NAME)
        val inputStream = FileInputStream(fileDescriptor.fileDescriptor)
        val fileChannel = inputStream.channel
        val startOffset = fileDescriptor.startOffset
        val declaredLength = fileDescriptor.declaredLength
        return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
    }


    private fun uriToBitmap(selectedFileUri: Uri): Bitmap? {
        try {
            val parcelFileDescriptor = contentResolver.openFileDescriptor(selectedFileUri, "r")
            val fileDescriptor: FileDescriptor = parcelFileDescriptor!!.fileDescriptor
            val image = BitmapFactory.decodeFileDescriptor(fileDescriptor)
            parcelFileDescriptor.close()
            return image
        } catch (e: IOException) {
            e.printStackTrace()
        }
        return null
    }

    private fun getOutputImage(output: ByteBuffer, outputWidth: Int, outputHeight: Int): Bitmap? {
        // TensorFlow Lite writes floats in the device's native byte order
        output.order(java.nio.ByteOrder.nativeOrder())
        output.rewind()
        val bitmap = Bitmap.createBitmap(outputWidth, outputHeight, Bitmap.Config.ARGB_8888)
        val pixels = IntArray(outputWidth * outputHeight)
        for (i in 0 until outputWidth * outputHeight) {
            val a = 0xFF
            // Assumes the model outputs values in 0..1; clamp so no channel overflows
            val r = (output.float * 255.0f).toInt().coerceIn(0, 255)
            val g = (output.float * 255.0f).toInt().coerceIn(0, 255)
            val b = (output.float * 255.0f).toInt().coerceIn(0, 255)
            pixels[i] = a shl 24 or (r shl 16) or (g shl 8) or b
        }
        bitmap.setPixels(pixels, 0, outputWidth, 0, 0, outputWidth, outputHeight)
        return bitmap
    }

    // save bitmap image to gallery
    private fun saveToGallery(context: Context, bitmap: Bitmap, albumName: String) {
        //val filename = "${System.currentTimeMillis()}.png"
        val write: (OutputStream) -> Boolean = {
            bitmap.compress(Bitmap.CompressFormat.PNG, 100, it)
        }

        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) {
            val contentValues = ContentValues().apply {
                put(MediaStore.MediaColumns.DISPLAY_NAME, filename)
                put(MediaStore.MediaColumns.MIME_TYPE, "image/png")
                put(MediaStore.MediaColumns.RELATIVE_PATH, "${Environment.DIRECTORY_DCIM}/$albumName")
            }

            context.contentResolver.let {
                it.insert(MediaStore.Images.Media.EXTERNAL_CONTENT_URI, contentValues)?.let { uri ->
                    // use {} closes the stream after writing
                    it.openOutputStream(uri)?.use(write)
                }
            }
        } else {
            val imagesDir = Environment.getExternalStoragePublicDirectory(Environment.DIRECTORY_DCIM).toString() + File.separator + albumName
            val file = File(imagesDir)
            if (!file.exists()) {
                file.mkdirs()
            }
            val image = File(imagesDir, filename)
            FileOutputStream(image).use(write)
        }
    }

    // get the filename from an image uri
    private fun Uri.getOriginalFileName(context: Context): String? {
        return context.contentResolver.query(this,
            null,
            null,
            null,
            null)?.use {
            val nameColumnIndex = it.getColumnIndex(OpenableColumns.DISPLAY_NAME)
            it.moveToFirst()
            it.getString(nameColumnIndex)
        }
    }
    fun convertArrayToBitmap(
        imageArray: Array<Array<Array<FloatArray>>>,
        imageWidth: Int,
        imageHeight: Int
    ): Bitmap {

        val conf = Bitmap.Config.ARGB_8888 // see other conf types
        val bitmap = Bitmap.createBitmap(imageWidth, imageHeight, conf)

        for (x in imageArray[0].indices) {          // x = row index
            for (y in imageArray[0][0].indices) {   // y = column index
                // Clamp to 0..255: Color.rgb does not range-check its arguments
                val color = Color.rgb(
                    imageArray[0][x][y][0].toInt().coerceIn(0, 255),
                    imageArray[0][x][y][1].toInt().coerceIn(0, 255),
                    imageArray[0][x][y][2].toInt().coerceIn(0, 255)
                )

                // setPixel takes (column, row), so y comes first here
                bitmap.setPixel(y, x, color)
            }
        }
        return bitmap
    }

}
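The order of width and height matters because TensorFlow Lite image tensors are NHWC. Element `(b, y, x, c)` sits at flat offset `((b * H + y) * W + x) * C + c`, which we assume is also the row-major layout `TensorImage` uses for a loaded bitmap. So `resizeInput` should get `(1, height, width, 3)`, and `ResizeOp` should get the target height before the target width. For a non-square image, swapping them reinterprets the same bytes with a different row length. A small sketch of the offset formula follows. `nhwcOffset` is a hypothetical helper:

```kotlin
// Hypothetical helper: flat offset of element (b, y, x, ch) in an NHWC tensor
// with height h, width w and c channels (row-major, channel varies fastest).
fun nhwcOffset(b: Int, y: Int, x: Int, ch: Int, h: Int, w: Int, c: Int): Int =
    ((b * h + y) * w + x) * c + ch

fun main() {
    val h = 118; val w = 124; val c = 3
    // Row 1, column 0 of a 118-high, 124-wide image: one full row of 124 RGB pixels in.
    println(nhwcOffset(0, 1, 0, 0, h, w, c)) // 372
    // Declaring the shape as width x height changes the row length,
    // so the same (row, column) now maps to a different offset.
    println(nhwcOffset(0, 1, 0, 0, w, h, c)) // 354
}
```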

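See how we resize the model input on Android, how we create the input buffer and the output array, and how we convert the resulting array into a bitmap. For programs like this, check whether you can use the phone's GPU to get roughly a 3x speedup. And of course, there is a lot worth reading in the official documentation.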
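Converting the output array into a bitmap comes down to packing three floats into one ARGB_8888 integer. `Color.rgb` performs no range check, so outputs slightly below 0 or above 255 would spill into neighbouring channels. Clamping first avoids that. The sketch below reproduces the packing in plain Kotlin. Like `convertArrayToBitmap`, it assumes the model outputs values in the 0 to 255 range and truncates them to integers. `packArgb` is a hypothetical helper:

```kotlin
// Hypothetical helper: pack float RGB values (expected range 0..255) into an
// opaque ARGB_8888 int, clamping first because out-of-range values would
// otherwise overflow into the neighbouring channels.
fun packArgb(r: Float, g: Float, b: Float): Int {
    val ri = r.toInt().coerceIn(0, 255)
    val gi = g.toInt().coerceIn(0, 255)
    val bi = b.toInt().coerceIn(0, 255)
    return (0xFF shl 24) or (ri shl 16) or (gi shl 8) or bi
}

fun main() {
    println(Integer.toHexString(packArgb(255f, 128f, 0f)))   // ffff8000
    println(Integer.toHexString(packArgb(300f, -5f, 12.7f))) // ffff000c (clamped)
}
```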