使用 deeplearning4j 加载 keras 模型时出错

Error on loading keras model with deeplearning4j

一段时间以来,我一直在努力使用 deeplearning4j 为我的 Android 应用程序加载我的 keras 神经网络模型。我已经搜索了解决方案(尽可能多),但每个解决方案都会带来新的错误,我就是无法让这个东西工作。

无论如何,我已经在 Python 中用 keras 训练了一个 NON 序列模型并像这样保存它:

model.save('model.h5')

现在我正尝试在 Android Studio 中使用 deeplearning4j 导入这个模型。我已经尝试了许多可能的变体,但这是我现在的位置:

String modelPath = new ClassPathResource("res/raw/model.h5").getFile().getPath();
ComputationGraph model = KerasModelImport.importKerasModelAndWeights(modelPath)

然而,这会触发以下错误:

java.lang.NoClassDefFoundError: Failed resolution of: Lorg/bytedeco/javacpp/hdf5;

据我了解,gradle 无法解决 org.bytedeco 的依赖项 hdf5,我同意这一点,因为我已在我的 [=52] 中排除了 hdf5-platform =] 构建,但据我所知 Android 甚至不应该支持 hdf5 (?)。

我也尝试包含 hdf5-platform 和 运行 相同的代码,但这样做会触发另一个错误:

java.lang.UnsatisfiedLinkError: Platform "android-arm64" not supported by class org.bytedeco.javacpp.hdf5

我对 gradle 概念还比较陌生,我对 Android 了解不深,但问题似乎出在我的 gradle 依赖项上。关于 deeplearning4j 的信息也有限,我也找不到替代解决方案。

我还将包括我从 this tutorial.

获得的 gradle 依赖项
implementation (group: 'org.deeplearning4j', name: 'deeplearning4j-core', version: '1.0.0-beta3') {
    exclude group: 'org.bytedeco.javacpp-presets', module: 'opencv-platform'
    exclude group: 'org.bytedeco.javacpp-presets', module: 'leptonica-platform'
    exclude group: 'org.bytedeco.javacpp-presets', module: 'hdf5-platform'
    exclude group: 'org.nd4j', module: 'nd4j-base64'
}
implementation group: 'org.nd4j', name: 'nd4j-native', version: '1.0.0-beta3'
implementation group: 'org.nd4j', name: 'nd4j-native', version: '1.0.0-beta3', classifier: "android-arm"
implementation group: 'org.nd4j', name: 'nd4j-native', version: '1.0.0-beta3', classifier: "android-arm64"
implementation group: 'org.nd4j', name: 'nd4j-native', version: '1.0.0-beta3', classifier: "android-x86"
implementation group: 'org.nd4j', name: 'nd4j-native', version: '1.0.0-beta3', classifier: "android-x86_64"
implementation group: 'org.bytedeco.javacpp-presets', name: 'openblas', version: '0.3.3-1.4.3'
implementation group: 'org.bytedeco.javacpp-presets', name: 'openblas', version: '0.3.3-1.4.3', classifier: "android-arm"
implementation group: 'org.bytedeco.javacpp-presets', name: 'openblas', version: '0.3.3-1.4.3', classifier: "android-arm64"
implementation group: 'org.bytedeco.javacpp-presets', name: 'openblas', version: '0.3.3-1.4.3', classifier: "android-x86"
implementation group: 'org.bytedeco.javacpp-presets', name: 'openblas', version: '0.3.3-1.4.3', classifier: "android-x86_64"
implementation group: 'org.bytedeco.javacpp-presets', name: 'opencv', version: '3.4.3-1.4.3'
implementation group: 'org.bytedeco.javacpp-presets', name: 'opencv', version: '3.4.3-1.4.3', classifier: "android-arm"
implementation group: 'org.bytedeco.javacpp-presets', name: 'opencv', version: '3.4.3-1.4.3', classifier: "android-arm64"
implementation group: 'org.bytedeco.javacpp-presets', name: 'opencv', version: '3.4.3-1.4.3', classifier: "android-x86"
implementation group: 'org.bytedeco.javacpp-presets', name: 'opencv', version: '3.4.3-1.4.3', classifier: "android-x86_64"
implementation group: 'org.bytedeco.javacpp-presets', name: 'leptonica', version: '1.76.0-1.4.3'
implementation group: 'org.bytedeco.javacpp-presets', name: 'leptonica', version: '1.76.0-1.4.3', classifier: "android-arm"
implementation group: 'org.bytedeco.javacpp-presets', name: 'leptonica', version: '1.76.0-1.4.3', classifier: "android-arm64"
implementation group: 'org.bytedeco.javacpp-presets', name: 'leptonica', version: '1.76.0-1.4.3', classifier: "android-x86"
implementation group: 'org.bytedeco.javacpp-presets', name: 'leptonica', version: '1.76.0-1.4.3', classifier: "android-x86_64"

(如何)我应该更改我的依赖关系以使该模型导入工作?

或者我应该改变导入模型的方式吗?

您可以尝试使用TF Lite加载模型。要在 Android 甚至 iOS 中加载 TensorFlow Keras 模型,您可以使用 TensorFlow Lite.

首先,您需要将 Keras ( .h5 ) 模型转换为 TFLite 模型 ( .tflite )

import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_keras_model_file( 'model.h5' )
tflite_model = converter.convert()
open( 'model.tflite' , 'wb' ).write( tflite_model )

您可以执行以下操作:

  1. 如果您的模型需要托管在将由您的应用程序下载的云源上,那么您可以使用 Firebase ML Kit. For custom TFLite models read here.

  2. 您可以将 TFLite 模型保存在应用程序的 assets 文件夹中,然后加载它的 MappedByteBuffer。 Android 的 TensorFlow Lite 依赖项可用:

    implementation ‘org.tensorflow:tensorflow-lite:2.3.0’
    

你可以参考这个codelab and this article

您可以像这样加载 MappedByteBuffer:

private MappedByteBuffer loadModelFile(Activity activity) throws IOException {
  AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(getModelPath());
  FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
  FileChannel fileChannel = inputStream.getChannel();
  long startOffset = fileDescriptor.getStartOffset();
  long declaredLength = fileDescriptor.getDeclaredLength();
  return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}