Input 0 is incompatible with layer flatten_5: expected min_ndim=3, found ndim=2
I am trying to fine-tune the VGG16 neural network. Here is the code:
vgg16_model = VGG16(weights="imagenet", include_top="false", input_shape=(224,224,3))
model = Sequential()
model.add(vgg16_model)
#add fully connected layer:
model.add(Flatten())
model.add(Dense(256, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(3, activation='softmax'))
It fails with the following error:
ValueError Traceback (most recent call last) in
2 model.add(vgg16_model)
3 #add fully connected layer:
----> 4 model.add(Flatten())
5 model.add(Dense(256, activation='relu'))
6 model.add(Dropout(0.5))
/usr/local/anaconda/lib/python3.6/site-packages/keras/engine/sequential.py in add(self, layer)
179 self.inputs = network.get_source_inputs(self.outputs[0])
180 elif self.outputs:
--> 181 output_tensor = layer(self.outputs[0])
182 if isinstance(output_tensor, list):
183 raise TypeError('All layers in a Sequential model '
/usr/local/anaconda/lib/python3.6/site-packages/keras/engine/base_layer.py in __call__(self, inputs, **kwargs)
412 # Raise exceptions in case the input is not compatible
413 # with the input_spec specified in the layer constructor.
--> 414 self.assert_input_compatibility(inputs)
416 # Collect input shapes to build layer.
/usr/local/anaconda/lib/python3.6/site-packages/keras/engine/base_layer.py in assert_input_compatibility(self, inputs)
325 self.name + ': expected min_ndim=' +
326 str(spec.min_ndim) + ', found ndim=' +
--> 327 str(K.ndim(x)))
328 # Check dtype.
329 if spec.dtype is not None:
ValueError: Input 0 is incompatible with layer flatten_5: expected min_ndim=3, found ndim=2
I have tried many of the suggested solutions, but none of them solved my problem. How can I fix this?
The official Keras website has an example, "Fine-tune InceptionV3 on a new set of classes", that uses the functional API. Adapted to your case:
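from keras.applications.vgg16 import VGG16
from keras.layers import Dense, GlobalAveragePooling2D
from keras.models import Model

# Convolutional base only: include_top is the boolean False
vgg16_model = VGG16(weights="imagenet", include_top=False, input_shape=(224, 224, 3))
x = vgg16_model.output
# Pool the 4D feature map down to 2D, as the Keras example does
x = GlobalAveragePooling2D()(x)
x = Dense(256, activation='relu')(x)
predictions = Dense(3, activation='softmax')(x)
model = Model(inputs=vgg16_model.input, outputs=predictions)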
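For reference, the shape that a head like this receives depends on whether the VGG16 top is included. Here is a rough sketch of the arithmetic in plain Python, no Keras required. It assumes the standard VGG16 layout (five 2x2 max-pooling stages, 512 channels in the last block, 1000 ImageNet classes); the function names are made up for illustration.

```python
# Sketch only: plain-Python shape arithmetic, no Keras required.
# Assumes the standard VGG16 layout: five 2x2 max-pooling stages,
# 512 channels in the last convolutional block, 1000 ImageNet classes.

def vgg16_output_shape(input_shape, include_top):
    """Return the per-sample output shape (batch axis omitted)."""
    height, width, _channels = input_shape
    if include_top:
        # The fully connected top ends in a 1000-way softmax.
        return (1000,)
    for _ in range(5):
        # Each pooling stage halves the spatial size (floor division).
        height, width = height // 2, width // 2
    return (height, width, 512)


def ndim_with_batch(shape):
    """Number of tensor axes once the batch axis is added back."""
    return len(shape) + 1


with_top = vgg16_output_shape((224, 224, 3), include_top=True)
without_top = vgg16_output_shape((224, 224, 3), include_top=False)

print(with_top, ndim_with_batch(with_top))        # (1000,) 2
print(without_top, ndim_with_batch(without_top))  # (7, 7, 512) 4
```

With the top included the output has ndim=2, which is the value reported in the error message. Without the top it has ndim=4, which a Flatten or pooling layer can accept.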
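You passed the string "false" in include_top="false". Any non-empty string is truthy in Python, so VGG16 keeps its fully connected top and outputs a 2D tensor, while Flatten expects at least 3 dimensions. Use the boolean False instead: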
vgg16_model = VGG16(weights="imagenet", include_top=False, input_shape=(224,224,3))
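A quick way to see the difference between the string and the boolean, as a sketch in plain Python (no Keras needed). Note that `require_bool` is a hypothetical helper introduced here for illustration, not a Keras function:

```python
# Sketch only: why include_top="false" behaves like include_top=True.
# `require_bool` is a hypothetical helper, not part of Keras.

def require_bool(name, value):
    """Reject non-boolean flags such as the string "false"."""
    if not isinstance(value, bool):
        raise TypeError(f"{name} must be True or False, got {value!r}")
    return value


print(bool("false"))  # True: any non-empty string is truthy
print(bool(False))    # False

include_top = require_bool("include_top", False)  # passes the check

try:
    require_bool("include_top", "false")
except TypeError as exc:
    error_message = str(exc)
    print(error_message)
```

A guard like this in your own setup code would have caught the typo before the model was built, rather than several layers later at Flatten.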