将 Python 2D ndarray 加载到 Android 以在 TFLite 上进行推理

Loading Python 2D ndarray into Android for inference on TFLite

我想在加载到 Android 项目的 TensorFlow Lite 模型上测试推理。

我在 Python 环境中生成了一些输入,我想保存到文件中,加载到我的 Android 应用程序中并用于 TFLite 推理。我的输入有点大,一个例子是:

<class 'numpy.ndarray'>, dtype: float32, shape: (1, 596, 80)

我需要一些方法来序列化这个 ndarray 并将其加载到 Android。

可以找到有关 TFLite 推理的更多信息 here。本质上,这应该是原始浮点数的多维数组,或 ByteBuffer。

最简单的方法是什么:

谢谢!

我终于弄明白了,有一个方便的 Java 库,名为 JavaNpy,它允许您打开 Java 中的 .npy 文件,因此 Android .

在 Python 这边,我以正常方式保存了一个扁平化的 .npy

data_flat = data.flatten()
print(data_flat.shape)
np.save(file="data.npy", arr=data_flat)

在 Android 中,我将其放入 assets 文件夹中。

然后我将它加载到 JavaNpy:

InputStream stream = context.getAssets().open("data.npy")
Npy npy = new Npy(stream);
float[] npyData = npy.floatElements();

最后将其转换为 TensorBuffer:

int[] inputShape = new int[]{1, 596, 80};   //the data shape before I flattened it
TensorBuffer tensorBuffer = TensorBuffer.createFixedSize(inputShape, DataType.FLOAT32);
tensorBuffer.loadArray(npyData);

然后我使用这个 tensorBuffer 对我的 TFLite 模型进行推理。