Python keras 神经网络 (Theano) 包 returns 关于数据维度的错误

Python keras neural network (Theano) package returns an error about data dimensions

我有这个代码:

import numpy as np
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation
from keras.optimizers import SGD
from sklearn import datasets
import theano

iris = datasets.load_iris()
X = iris.data[:,0:3]  # we only take the first two features.
Y = iris.target

X = X.astype(theano.config.floatX)
Y = Y.astype(theano.config.floatX)


model = Sequential()
model.add(Dense(150, 1, init='uniform'))
model.add(Activation('tanh'))
model.add(Dropout(0.5))
model.add(Dense(150, 1, init='uniform'))
model.add(Activation('tanh'))
model.add(Dropout(0.5))
model.add(Dense(150, 1, init='uniform'))
model.add(Activation('softmax'))

sgd = SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='mean_squared_error', optimizer=sgd)

model.fit(X, Y, nb_epoch=20, batch_size=150)


score = model.evaluate(X_train, y_train, batch_size=16)

Returns 这个错误:

ValueError: Shape mismatch: x has 3 cols (and 150 rows) but y has 150 rows (and 1 cols)
Apply node that caused the error: Dot22(<TensorType(float64, matrix)>, <TensorType(float64, matrix)>)
Inputs types: [TensorType(float64, matrix), TensorType(float64, matrix)]
Inputs shapes: [(150L, 3L), (150L, 1L)]
Inputs strides: [(24L, 8L), (8L, 8L)]
Inputs values: ['not shown', 'not shown']

有什么问题?

您为内部图层指定了错误的输出尺寸。请参阅 Keras 文档中的示例:

model = Sequential()
model.add(Dense(20, 64, init='uniform'))
model.add(Activation('tanh'))
model.add(Dropout(0.5))
model.add(Dense(64, 64, init='uniform'))
model.add(Activation('tanh'))
model.add(Dropout(0.5))
model.add(Dense(64, 2, init='uniform'))
model.add(Activation('softmax'))

注意一层的输出大小如何匹配下一层的输入大小:

20x64 -> 64x64 -> 64x2

第一个数字始终是输入大小(上一层的神经元数量),第二个数字始终是输出大小(下一层的神经元数量)。所以在这个例子中你有四层:

  • 具有 20 个神经元的输入层
  • 具有 64 个神经元的隐藏层
  • 具有 64 个神经元的隐藏层
  • 具有 2 个神经元的输出层

您唯一的硬性限制是第一个(输入)层需要具有与您的特征一样多的神经元,而最后一个(输出)层需要具有与您的任务所需一样多的神经元。

对于你的例子,因为你有三个特征,你需要将输入层大小更改为 3,并且你可以保留这个例子中的两个输出神经元来进行二元分类(或者像你一样使用一个,有物流损失)。