使用 Keras 进行多维回归

Multi-dimensional regression with Keras

我想使用 Keras 训练神经网络进行二维回归。

我的输入是一个数字,我的输出有两个数字:

model = Sequential()
model.add(Dense(16, input_shape=(1,), kernel_initializer=initializers.constant(0.0), bias_initializer=initializers.constant(0.0)))
model.add(Activation('relu'))
model.add(Dense(16, input_shape=(1,), kernel_initializer=initializers.constant(0.0), bias_initializer=initializers.constant(0.0)))
model.add(Activation('relu'))
model.add(Dense(2, kernel_initializer=initializers.constant(0.0), bias_initializer=initializers.constant(0.0)))
adam = Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0)
model.compile(loss='mean_squared_error', optimizer=adam)

然后我创建了一些用于训练的虚拟数据:

inputs = np.zeros((10, 1), dtype=np.float32)
targets = np.zeros((10, 2), dtype=np.float32)

for i in range(10):
    inputs[i] = i / 10.0
    targets[i, 0] = 0.1
    targets[i, 1] = 0.01 * i

最后,我在循环中使用小批量训练,同时测试训练数据:

while True:

    loss = model.train_on_batch(inputs, targets)

    test_outputs = model.predict(inputs)

    print test_outputs

问题是,打印出来的结果如下:

[0.1, 0.045]
[0.1, 0.045]
[0.1, 0.045]
.....
.....
.....

因此,虽然第一个维度是正确的 (0.1),但第二个维度不正确。第二个维度应该是 [0.01, 0.02, 0.03, .....]。所以实际上,网络的输出 (0.45) 只是第二维中所有值的平均值。

我做错了什么?

问题是,您将所有权重初始化为零。问题是,如果所有权重都相同,那么所有梯度都相同。所以就好像你有一个网络,每一层都有一个神经元。删除它以便使用默认的随机初始化并且它有效:

model = Sequential()
model.add(Dense(16, input_shape=(1,)))
model.add(Activation('relu'))
model.add(Dense(16, input_shape=(1,)))
model.add(Activation('relu'))
model.add(Dense(2))
model.compile(loss='mean_squared_error', optimizer='Adam')

1000个epoch后的结果:

Epoch 1000/1000
10/10 [==============================] - 0s - loss: 5.2522e-08

In [59]: test_outputs
Out[59]:
array([[ 0.09983768,  0.00040025],
       [ 0.09986718,  0.010469  ],
       [ 0.09985521,  0.02051429],
       [ 0.09984323,  0.03055958],
       [ 0.09983127,  0.04060487],
       [ 0.09995781,  0.05083206],
       [ 0.09995599,  0.06089856],
       [ 0.09995417,  0.07096504],
       [ 0.09995237,  0.08103154],
       [ 0.09995055,  0.09109804]], dtype=float32)