TensorRT 的 PluginFormat 必须是 kNCHW?

TensorRT's PluginFormat has to be kNCHW?

我的Tensorflow模型如下(模型的一部分)。

Tensorflow 模型需要采用 NHWC 格式才能输入图像及其处理。 Tensorflow模型转TensorRT引擎,upsample需要实现Plugin。

但是 TensorRT 插件需要格式必须是 PluginFormat::kNCHW。 如果设置为PluginFormat::kNHWC,插件无法编译。

那么如何为这样的 Tensorflow 模型创建插件呢?

是的,TensorRT 插件需要是 NCHW 格式。要使用 NHWC 格式的 Tensorflow 模型,处理部分(例如 CUDA 代码中的 运行 部分)需要设计为处理 NCHW 格式的输入数组。然后,如果 Tensorflow 模型是 NHWC 格式,则在插件的输出中重新格式化回 NHWC。