检查输入时出错:预期 conv2d_1_input 具有形状 (28, 28, 1) 但得到形状为 (3, 224, 224) 的数组

Error when checking input: expected conv2d_1_input to have shape (28, 28, 1) but got array with shape (3, 224, 224)

我该如何解决? 下面显示了我使用的代码

这是为了将图像转换为矢量

import cv2
import numpy as np

file = cv2.imread('17316.png')
file = cv2.resize(file, (224, 224))
file = cv2.cvtColor(file, cv2.COLOR_BGR2RGB)
file = np.array(file).reshape((1, 3, 224, 224))
print(file.shape[0])

这是我应用的卷积神经网络的一部分,它导致了那个错误我该怎么办,我该如何解决它请建议我更改代码以便我可以对我的数据集进行正确的预测?

model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3),
                 activation='relu',
                 input_shape=input_shape))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax'))

从错误消息中可以明显看出模型期望图像形状为 (28,28,1)。所以在将图像输入模型之前尝试调整图像大小。

file = cv2.imread('17316.png')
file = cv2.resize(file, (28, 28))
file = cv2.cvtColor(file, cv2.COLOR_BGR2GRAY)
file = file.reshape((-1, 28, 28,1))

这将解决问题。

可以post你的完整代码吗? input_shape 的值是多少?我认为您应该将其设置为 (3, 224, 224)。显然,您的 data_format 是 channels_first,根据 Keras conv2d documentation,默认值为 channels_last。所以,我建议你使用你的第一个卷积层

model.add(Conv2D(32, kernel_size = (3, 3),
                 activation = 'relu',
                 input_shape = (3, 224, 224), 
                 data_format = "channels_first")

更新:根据您的代码,以下操作应该有效,但可能不会产生您想要的结果。您正在 mnist 数据集上进行训练,该数据集需要 28x28x1 格式的图像,因此您必须调整大小,如 Mitiku 的答案所示。希望对您有所帮助。

import keras
import cv2 
import numpy as np 
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten 
from keras.layers import Conv2D, MaxPooling2D 
from keras import backend as K 

batch_size = 128 
num_classes = 10 
epochs = 1 
img_rows, img_cols = 28, 28
(x_train, y_train), (x_test, y_test) = mnist.load_data() 

if K.image_data_format() == 'channels_first': 
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols) 
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols) 
    input_shape = (1, img_rows, img_cols) 
else: 
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1) 
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1) 
    input_shape = (img_rows, img_cols, 1) 
    x_train = x_train.astype('float32')

x_test = x_test.astype('float32') 
x_train /= 255 
x_test /= 255 
print('x_train shape:', x_train.shape) 
print(x_train.shape[0], 'train samples') 
print(x_test.shape[0], 'test samples') 
y_train = keras.utils.to_categorical(y_train, num_classes) 
y_test = keras.utils.to_categorical(y_test, num_classes)

model = Sequential() 
model.add(Conv2D(32, kernel_size=(3, 3), activation='relu',\
                 input_shape = input_shape))
model.add(Conv2D(64, (3, 3), activation='relu')) 
model.add(MaxPooling2D(pool_size=(2, 2))) 
model.add(Dropout(0.25)) 
model.add(Flatten()) 
model.add(Dense(128, activation='relu')) 
model.add(Dropout(0.5)) 
model.add(Dense(num_classes, activation='softmax')) 
model.compile(loss=keras.losses.categorical_crossentropy, \
              optimizer=keras.optimizers.Adadelta(), \
              metrics=['accuracy'])

model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, \
          verbose=1, validation_data=(x_test, y_test)) 
score = model.evaluate(x_test, y_test, verbose=0) 

print('Test loss:', score[0]) 
print('Test accuracy:', score[1]) 

file = cv2.imread('17316.png') 
file = cv2.resize(file, (28, 28))
file = cv2.cvtColor(file, cv2.COLOR_BGR2GRAY)
file = file.reshape((28, 28,1))

model.predict(np.expand_dims(file, axis = 0))

更新 2: 对于 mnist 数据集,您有 10 classes。你有一个二进制 classification 问题。你的输出 class 将你的图像确定为 class 8,它对应于数字 7,因为 mnist 数据集 classes 是从 0 到 9 的数字。我们必须知道 classes 已编码 - 这是特定于问题的。在这种情况下,要 return 你可以做的数字:

prediction = model.predict(np.expand_dims(file, axis = 0))
prediction = np.squeeze(prediction)
index = np.where(prediction == 1)[0]
number = (index - 1).item()
print("predicted number for my image: ", number)

最后一行 returns 预测包含 1 的索引,由于索引从 1 开始,您可以从索引中减去一个以获得与您的图像对应的数字。