"The CPU implementation of FusedBatchNorm only supports NHWC tensor format for now."

"The CPU implementation of FusedBatchNorm only supports NHWC tensor format for now."

I am trying to recreate a model and reproduce its published results. I am using TF 2.0, and I believe the model was originally written against the Theano backend. The model comes from a repo on GitHub, for context.

I am also not using tensorflow-gpu, because it is not compatible with my hardware setup.

Anyway, at first I ran into a lot of errors just trying to load its weights, or even the model itself. When I realized the save/load functions were simply a mess, I moved on to building and training the model myself. Since I am using TensorFlow, I modified the code to use the 'channels_last' format, i.e. NHWC, as the error puts it.

Here is my list of imports:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from tensorflow import keras

import cv2
import os
import pathlib
import shutil

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

The modified model:

def createModel():
    model = keras.models.Sequential()
    model.add(keras.layers.Lambda(norm_input, input_shape=(28,28,1), output_shape=(28,28,1)))
    model.add(keras.layers.Conv2D(32, (3,3)))
    model.add(keras.layers.LeakyReLU())
    model.add(keras.layers.BatchNormalization(axis=1))
    model.add(keras.layers.Conv2D(32, (3,3)))
    model.add(keras.layers.LeakyReLU())
    model.add(keras.layers.MaxPooling2D())
    model.add(keras.layers.BatchNormalization(axis=1))
    model.add(keras.layers.Conv2D(64, (3,3)))
    model.add(keras.layers.LeakyReLU())
    model.add(keras.layers.BatchNormalization(axis=1))
    model.add(keras.layers.Conv2D(64, (3,3)))
    model.add(keras.layers.LeakyReLU())
    model.add(keras.layers.MaxPooling2D())
    model.add(keras.layers.Flatten())
    model.add(keras.layers.BatchNormalization())
    model.add(keras.layers.Dense(512))
    model.add(keras.layers.LeakyReLU())
    model.add(keras.layers.BatchNormalization())
    model.add(keras.layers.Dropout(0.3))
    model.add(keras.layers.Dense(10, activation='softmax'))

    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

    return model

How I load and preprocess the MNIST dataset:

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

test_labels = y_test

x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)

x_train = x_train.astype(np.float32)
x_test = x_test.astype(np.float32)

x_train /= 255
x_test /= 255

y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)
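
The Lambda layer in the model refers to a norm_input helper that I haven't shown; it just standardizes the input with training-set statistics. Roughly like this (a sketch; the exact form is not relevant to the error):

# Sketch of norm_input: per-pixel standardization using training-set statistics.
mean_px = x_train.mean().astype(np.float32)
std_px = x_train.std().astype(np.float32)

def norm_input(x):
    return (x - mean_px) / std_px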

ImageDataGenerator:

    gen = keras.preprocessing.image.ImageDataGenerator(
            rotation_range=12, 
            width_shift_range=0.1, 
            shear_range=0.3,
            height_shift_range=0.1, 
            zoom_range=0.1, 
            data_format='channels_last')
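
The batches and step counts used below come from this generator in the usual way; roughly like this (the exact batch size is not important for the error):

# Rough sketch of how the batches and step counts are set up; batch_size is arbitrary.
batch_size = 64
batches = gen.flow(x_train, y_train, batch_size=batch_size)
test_batches = gen.flow(x_test, y_test, batch_size=batch_size)
steps_per_epoch = int(np.ceil(len(x_train) / batch_size))
validation_steps = int(np.ceil(len(x_test) / batch_size))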

Finally, the function that trains the model:

def fit_model(m):
    m.fit_generator(batches, steps_per_epoch=steps_per_epoch, epochs=1, verbose=0,
                   validation_data=test_batches, validation_steps=validation_steps)
    m.optimizer.lr = 0.1
    m.fit_generator(batches, steps_per_epoch=steps_per_epoch, epochs=4, verbose=0,
                   validation_data=test_batches, validation_steps=validation_steps)
    m.optimizer.lr = 0.01
    m.fit_generator(batches, steps_per_epoch=steps_per_epoch, epochs=12, verbose=0,
                   validation_data=test_batches, validation_steps=validation_steps)
    m.optimizer.lr = 0.001
    m.fit_generator(batches, steps_per_epoch=steps_per_epoch, epochs=18, verbose=0,
                   validation_data=test_batches, validation_steps=validation_steps)
    return m

That last snippet is where the error points, but I don't know what exactly is wrong there that relates to the image format. Specifically, it points to the third line, i.e. the one starting with validation_data=....

The full error is:

Component function execution failed: Internal: The CPU implementation of FusedBatchNorm only supports NHWC tensor format for now.
     [[{{node batch_normalization_v2/cond/then/_0/FusedBatchNorm}}]]

And the traceback:

Traceback (most recent call last):
  File "model3.py", line 113, in <module>
    m = fit_model(createModel())
  File "model3.py", line 52, in fit_model
    validation_data=test_batches, validation_steps=validation_steps)
  File "/home/ren/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 1515, in fit_generator
    steps_name='steps_per_epoch')
  File "/home/ren/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_generator.py", line 257, in model_iteration
    batch_outs = batch_function(*batch_data)
  File "/home/ren/.local/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py", line 1259, in train_on_batch
    outputs = self._fit_function(ins)  # pylint: disable=not-callable
  File "/home/ren/.local/lib/python3.6/site-packages/tensorflow/python/keras/backend.py", line 3217, in __call__
    outputs = self._graph_fn(*converted_inputs)
  File "/home/ren/.local/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 558, in __call__
    return self._call_flat(args)
  File "/home/ren/.local/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 627, in _call_flat
    outputs = self._inference_function.call(ctx, args)
  File "/home/ren/.local/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 415, in call
    ctx=ctx)
  File "/home/ren/.local/lib/python3.6/site-packages/tensorflow/python/eager/execute.py", line 66, in quick_execute
    six.raise_from(core._status_to_exception(e.code, message), None)
  File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InternalError: The CPU implementation of FusedBatchNorm only supports NHWC tensor format for now.
     [[{{node batch_normalization_v2/cond/then/_0/FusedBatchNorm}}]] [Op:__inference_keras_scratch_graph_3602]

I was hoping it would be fixed when I added the line tf.keras.backend.set_image_data_format('channels_last') at the top of my code. For good measure, I even passed the same argument to the ImageDataGenerator mentioned above. So honestly, I have no idea what I'm missing or where I went wrong.
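
For reference, that line sits right after the imports:

# Added after the imports, in the hope of forcing channels_last / NHWC everywhere.
tf.keras.backend.set_image_data_format('channels_last')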

Disclaimer

I am just a beginner in this field, and this answer is based only on my reading of the documentation. I don't fully understand the possible implications and their effect on the model; in the worst case it might affect its performance and accuracy. At best, this solution should be treated only as a temporary patch that suppresses the error, at least until someone more knowledgeable confirms it. Any confirmation, clarification, suggestion, or alternative solution is welcome.

From the Keras documentation on BatchNormalization:

axis: Integer, the axis that should be normalized (typically the features axis). For instance, after a Conv2D layer with data_format="channels_first", set axis=1 in BatchNormalization.

Since I am using channels_last (NHWC) as the image format, simply setting axis to -1 in the arguments fixes it. To be explicit:

keras.layers.BatchNormalization(axis=-1)

solved my problem.
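
Applied to the model above, every channels-first style call is changed like this (-1 is also the default, so omitting the axis argument entirely works too):

# Before (channels_first style; triggers the FusedBatchNorm NHWC error on CPU):
model.add(keras.layers.BatchNormalization(axis=1))

# After (channels_last / NHWC, the format the data is actually in):
model.add(keras.layers.BatchNormalization(axis=-1))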

If you really want batch normalization to operate on the first axis, this little trick works around the error:

Replace: model.add(keras.layers.BatchNormalization(axis=1))

with:

model.add(keras.layers.Permute((2, 3, 1)))
model.add(keras.layers.BatchNormalization(axis=-1))
model.add(keras.layers.Permute((3, 1, 2)))
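
This keeps the statistics computed over the original first axis (the height axis, for NHWC data) while the underlying FusedBatchNorm op only ever sees a channels-last tensor; the cost is two extra transpose operations around every normalization layer.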