无法从具有 602112 字节的 Java 缓冲区复制到具有 150528 字节的 TensorFlowLite 张量 (input_1)

Cannot copy to a TensorFlowLite tensor (input_1) with 150528 bytes from a Java Buffer with 602112 bytes

我正在尝试在 tflitecamerademo 示例中使用我的模型。

这是我的模型

演示崩溃,原因如下

java.lang.IllegalArgumentException: Cannot copy to a TensorFlowLite tensor (input_1) with 150528 bytes from a Java Buffer with 602112 bytes.

我按照google的例子初始化字节缓冲区

imgData = ByteBuffer.allocateDirect(4 * DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE);

imgData.order(ByteOrder.nativeOrder());

DIM_BATCH_SIZE = 1
DIM_IMG_SIZE_X = 224
DIM_IMG_SIZE_Y = 224
DIM_PIXEL_SIZE = 3

然后我将图像调整为净分辨率并将其转换为字节缓冲区

Bitmap reshapeBitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, false);
convertBitmapToByteBuffer(reshapeBitmap);


private void convertBitmapToByteBuffer(Bitmap bitmap) {
        if (imgData == null) {
            return;
        }
        imgData.rewind();
        bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
        // Convert the image to floating point.
        int pixel = 0;
        long startTime = SystemClock.uptimeMillis();
        for (int i = 0; i < DIM_IMG_SIZE_X; ++i) {
            for (int j = 0; j < DIM_IMG_SIZE_Y; ++j) {
                final int val = intValues[pixel++];
                imgData.putFloat((((val >> 16) & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
                imgData.putFloat((((val >> 8) & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
                imgData.putFloat((((val) & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
            }
        }
        long endTime = SystemClock.uptimeMillis();
        //Log.d("Timecost to put values into ByteBuffer: " + Long.toString(endTime - startTime));
    }

最后,我运行检测

tflite.run(imgData, labelProbArray);

此处发生崩溃是由于输入大小与缓冲区大小不同。

现在,如果我们手动乘以 1 * 224 * 224 * 3 * 4,我们将得到 602112,这是正确的大小。为什么我的代码缺少最后一个乘法。

这是类型不匹配造成的。
根据模型描述,你有整数类型input/output,可能是量化模型。
您正在尝试准备要提供的浮点数据缓冲区。有 2 种最常见的解决方案:

1)准备uint8数据。将位图像素作为 1 字节 uint8 写入字节缓冲区:

imgData = ByteBuffer.allocateDirect(DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE); // now buffer size and input size match

imgData.order(ByteOrder.nativeOrder());

Bitmap reshapeBitmap = Bitmap.createScaledBitmap(bitmap, 224, 224, false);
convertBitmapToByteBuffer(reshapeBitmap);


private void convertBitmapToByteBuffer(Bitmap bitmap) {
        if (imgData == null) {
            return;
        }
        imgData.rewind();
        bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
        // Convert the image to floating point.
        int pixel = 0;
        long startTime = SystemClock.uptimeMillis();
        for (int i = 0; i < DIM_IMG_SIZE_X; ++i) {
            for (int j = 0; j < DIM_IMG_SIZE_Y; ++j) {
                final int val = intValues[pixel++];
                imgData.putChar((byte)((((val >> 16) & 0xFF)-IMAGE_MEAN)/IMAGE_STD*255));
                imgData.putChar((byte)((((val >> 8) & 0xFF)-IMAGE_MEAN)/IMAGE_STD*255));
                imgData.putChar((byte)((((val) & 0xFF)-IMAGE_MEAN)/IMAGE_STD*255));
            }
        }
        long endTime = SystemClock.uptimeMillis();
        //Log.d("Timecost to put values into ByteBuffer: " + Long.toString(endTime - startTime));
    }

另请查看 tflite support lib,它简化了您处理 input/output 数据的方式,可能会派上用场。

2)查找具有浮点输入的模型并使用您的代码

我们可以只使用 ImageProcessor CastOp(DataType.UINT8) 将位图转换为 uint8。

ImageProcessor imageProcessor;
TensorImage xceptionTfliteInput;
if(IS_INT8){
    imageProcessor =
           new ImageProcessor.Builder()
                            .add(new ResizeOp(INPNUT_SIZE.getHeight(), INPNUT_SIZE.getWidth(), ResizeOp.ResizeMethod.BILINEAR))
                            .add(new NormalizeOp(0, 255))
                            .add(new QuantizeOp(inputQuantParams.getZeroPoint(), inputQuantParams.getScale()))
                            .add(new CastOp(DataType.UINT8))
                            .build();
    xceptionTfliteInput = new TensorImage(DataType.UINT8);
} else {
    imageProcessor =
          new ImageProcessor.Builder()
                            .add(new ResizeOp(INPNUT_SIZE.getHeight(), INPNUT_SIZE.getWidth(), ResizeOp.ResizeMethod.BILINEAR))
                            .add(new NormalizeOp(0, 255))
                            .build();
     xceptionTfliteInput = new TensorImage(DataType.FLOAT32);
}
xceptionTfliteInput.load(bitmap);
xceptionTfliteInput = imageProcessor.process(xceptionTfliteInput);