keras顺序模型的正确输入和输出形状

proper input and output shape of a keras Sequential model

我正在尝试 运行 一个 Keras 顺序模型,但无法获得适合模型训练的正确形状。

我将 xy 重塑为:

x = x.reshape(len(x), 500)
y = y.reshape(len(y), 500)

目前,输入形状和输出形状均为:

(9766, 500)
(9766, 500)

数据集分别由 9766 个输入和 9766 个输出组成。每个输入都是一个包含 500 个值的数组,每个输出也是一个包含 500 个值的数组。

所以这是一个单一的输入数组:

[0.99479668 0.99477965 0.99484778 0.99489887 0.99483926 0.99451565
 0.99458378 0.99457526 0.99453268 0.99468597 0.99466042 0.99449862
 0.99453268 0.99454971 0.99463487 0.99461784 0.99451565 0.99463487
 0.99467745 0.99502661 0.99480519 0.99493294 0.99493294 0.99522248
 0.99526506 0.99528209 0.99527358 0.99515435 0.99529913 0.99488184
 0.99508623 0.99512881 0.99522248 0.99497552 0.9954439  0.99554609
 0.99581861 0.99573345 0.9957079  0.99626144 0.99626144 0.99592932
 0.99558867 0.99541835 0.99524803 0.99586119 0.99601448 0.99588674
 0.99584416 0.99559719 0.995495   0.99520545 0.99552055 0.99510326
 0.9951799  0.99560571 0.99561422 0.99541835 0.99586119 0.995759
 0.9957079  0.99583564 0.9959208  0.99578454 0.99604854 0.99612519
 0.99609112 0.99630402 0.9961337  0.99672983 0.99655099 0.99643176
 0.99643176 0.99648286 0.99649138 0.99645731 0.99670428 0.99654247
 0.99647435 0.99607409 0.99589525 0.99600596 0.99596338 0.99621035
 0.99633809 0.99632106 0.99583564 0.99581009 0.99574196 0.9959719
 0.99557164 0.99567383 0.99572493 0.9958697  0.99568235 0.9959208
 0.99598893 0.99620183 0.99611667 0.99620183 0.9959719  0.9957079
 0.99612519 0.99558867 0.99569938 0.99518842 0.99553758 0.99552055
 0.99576751 0.99577603 0.99583564 0.99602299 0.99630402 0.99637215
 0.99701937 0.99701086 0.99731744 0.99700234 0.99696828 0.99668725
 0.99703641 0.99725782 0.99684054 0.99605706 0.99608261 0.99581861
 0.9958697  0.99583564 0.99566532 0.99585267 0.99566532 0.99604003
 0.99540984 0.99473707 0.995231   0.99441346 0.9942261  0.99397914
 0.99367256 0.99409836 0.99415797 0.99420907 0.99398765 0.99356185
 0.99382585 0.99428571 0.9945412  0.99444752 0.99436236 0.99404726
 0.9938003  0.99424313 0.99483074 0.99474558 0.99457526 0.99457526
 0.99465191 0.99466042 0.99467745 0.99448158 0.99454971 0.99479668
 0.994703   0.99455823 0.99472855 0.99507771 0.99529913 0.99515435
 0.99525655 0.99621886 0.99586119 0.99576751 0.9962359  0.99614222
 0.99723228 0.99685757 0.99680647 0.99689163 0.99644028 0.99701937
 0.99675538 0.99637215 0.99614222 0.99628699 0.9964488  0.99641473
 0.99652544 0.99652544 0.99664467 0.99698531 0.99712157 0.99703641
 0.99799872 0.99859485 0.99876517 0.99950607 0.99902065 0.99891846
 0.99804982 0.99839898 0.99857782 0.99850117 0.99891846 0.99912284
 0.99919097 0.99919949 0.99896956 0.99896104 0.99877369 0.99898659
 0.99918246 0.99890994 0.9990462  0.99895252 0.99885033 0.99871407
 0.99871407 0.99871407 0.99864594 0.99854375 0.9983564  0.9985693
 0.99870556 0.99868001 0.9987822  0.99877369 0.99900362 0.99882478
 0.99896956 0.99885885 0.99880775 0.99890994 0.99906323 0.99908026
 0.9990462  0.99921652 0.99920801 0.99936129 0.99937833 0.99943794
 0.99935278 0.99943794 0.99967639 0.99956568 0.99960826 0.99962529
 0.99942942 0.99940387 0.9992591  0.99908878 0.99912284 0.99913988
 0.99905472 0.99914839 0.99913136 0.99933575 0.99935278 0.99929317
 0.99931871 0.99905472 0.99965084 0.99995742 1.         0.99962529
 0.999472   0.99939536 0.99932723 0.99929317 0.99931871 0.99931871
 0.99950607 0.99953162 0.99942942 0.99919097 0.99902917 0.99913988
 0.99915691 0.9990462  0.9990973  0.99923355 0.99940387 0.99954865
 0.99958271 0.99940387 0.99943794 0.99928465 0.9990973  0.99905472
 0.99915691 0.99921652 0.99913988 0.99913136 0.99912284 0.9992591
 0.99916542 0.99917394 0.99918246 0.99906323 0.99905472 0.99907175
 0.99901214 0.9990462  0.99913988 0.9990462  0.9990462  0.99880775
 0.99890994 0.99868852 0.99868852 0.99889291 0.99896956 0.99886736
 0.99932723 0.99943794 0.99932723 0.99931871 0.99931871 0.99921652
 0.99874814 0.99871407 0.99915691 0.99969342 0.99962529 0.99916542
 0.99902917 0.99887588 0.99919097 0.99943794 0.99847562 0.9988333
 0.99905472 0.99913988 0.99931871 0.99936129 0.99893549 0.99869704
 0.99842453 0.99868001 0.99868852 0.9987822  0.9987311  0.99871407
 0.99860336 0.99826272 0.99805834 0.99785395 0.99792208 0.99804982
 0.99797317 0.99797317 0.99778582 0.99749627 0.99751331 0.99758143
 0.99732595 0.99741111 0.99699383 0.99733447 0.99728337 0.99686608
 0.99714712 0.9973515  0.99753885 0.99753034 0.99762402 0.99774324
 0.99781989 0.99765808 0.99739408 0.9974026  0.99723228 0.99737705
 0.99728337 0.99728337 0.99736002 0.99726634 0.99732595 0.99721524
 0.99728337 0.99701937 0.99715563 0.99715563 0.99744518 0.99753034
 0.99747073 0.99765808 0.9978284  0.99726634 0.99724931 0.99776879
 0.99746221 0.9976666  0.9976666  0.99744518 0.99734298 0.99833085
 0.99866298 0.99800724 0.99714712 0.99648286 0.99588674 0.99598041
 0.99563125 0.99595486 0.99626144 0.99601448 0.99456674 0.9947541
 0.99499255 0.99483926 0.9950181  0.99497552 0.99484778 0.99424313
 0.99416649 0.99416649 0.9942772  0.99288908 0.99266766 0.99293166
 0.99248031 0.99312753 0.99269321 0.99307643 0.99286353 0.99319566
 0.99346817 0.99337449 0.99322972 0.99302534 0.99322121 0.99307643
 0.99295721 0.99344262 0.99262508 0.99259953 0.99246327 0.99254844
 0.99265063 0.99288908 0.99288908 0.9930594  0.9933234  0.99340004
 0.99320417 0.99331488 0.99319566 0.99335746 0.99322121 0.99271876
 0.99271024 0.99270172 0.99259102 0.99308495 0.99331488 0.9930083
 0.99285501 0.99289759 0.99276134 0.99259102 0.99266766 0.99221631
 0.99216521 0.99225889 0.99227592 0.99196934 0.99162018 0.99147541
 0.99134767 0.99159463 0.99152651 0.99166276 0.99169683 0.99168831
 0.99175644 0.99178199 0.99161167 0.99165425 0.99170534 0.9915776
 0.9915776  0.99144135 0.99169683 0.99170534 0.99144986 0.99170534
 0.99187567 0.99192676 0.99183308 0.99177347 0.99173941 0.99176496
 0.99170534 0.9917905  0.99178199 0.99144986 0.99147541 0.99142431
 0.99149244 0.99139877]

这是一个输出数组:

[0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.99449862
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.99731744 0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.99356185
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         1.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.99686608
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.99866298 0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.99134767 0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.         0.         0.         0.         0.
 0.         0.        ]

这是我正在尝试训练数据的模型(很可能架构不佳):

model = Sequential()
model.add(LSTM(128, input_shape=(x.shape[1:]), return_sequences=True))
model.add(Dropout(0.2))

model.add(LSTM(128, input_shape=(x.shape[1:])))
model.add(Dropout(0.1))

model.add(LSTM(32, input_shape=(x.shape[1:]) ,activation = 'relu'))
model.add(Dropout(0.2))

model.add (Dense(1 ,activation = 'sigmoid'))
opt = tf.keras.optimizers.Adam(lr=0.001, decay=1e-3)
model.compile(loss='mse',optimizer=opt, metrics=['accuracy'])
model.fit(x,y,epochs=20,validation_split=0.20)

我希望模型训练的方式是查看输入并生成一个包含 500 个值的数组,如上所示的输出数组。 但无论我尝试什么形状,我都会收到如下错误:

ValueError: Input 0 of layer "lstm" is incompatible with the layer: expected ndim=3, found ndim=2. Full shape received: (None, 500)

这里的正确形状是什么形状,我在模型架构上做错了什么?

更新 1:

我还尝试将 xy 重塑为:

(9766, 1, 500)
(9766, 1, 500)

仍然没有运气。

LSTM 层期望输入形状为 [batch, timesteps, feature]。因此,对于形状 (9766, 1, 500),您有一个包含 500 个特征的时间步长。如果你有 500 个时间步,你的形状应该像 (9766, 500, 1).

这是一个示例架构:

x = tf.random.uniform((9766,500,1))
y = tf.random.uniform((9766,500,1))


model = tf.keras.models.Sequential()
model.add(tf.keras.layers.InputLayer(input_shape=(x.shape[1:])))
model.add(tf.keras.layers.LSTM(128, return_sequences=True, activation='relu'))
model.add(tf.keras.layers.Dropout(0.2))

model.add(tf.keras.layers.LSTM(128, activation='relu', return_sequences=True))
model.add(tf.keras.layers.Dropout(0.1))

model.add(tf.keras.layers.LSTM(32, activation = 'relu', return_sequences=True))
model.add(tf.keras.layers.Dropout(0.2))

model.add(tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(1 ,activation = 'sigmoid'))) # You can also remove timedistributed wrapper if you get better result. I supposed you need to have your output values between 0.0 and 1.0
model.compile(loss='mse',optimizer=tf.keras.optimizers.Adam(lr=0.001, decay=1e-3), metrics=['accuracy']) # Be careful about your chosen metric. 
model.summary()

如果您检查模型摘要,您会看到输入和输出形状与您预期的相同:

_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 lstm_15 (LSTM)              (None, 500, 128)          66560     
                                                                 
 dropout_13 (Dropout)        (None, 500, 128)          0         
                                                                 
 lstm_16 (LSTM)              (None, 500, 128)          131584    
                                                                 
 dropout_14 (Dropout)        (None, 500, 128)          0         
                                                                 
 lstm_17 (LSTM)              (None, 500, 32)           20608     
                                                                 
 dropout_15 (Dropout)        (None, 500, 32)           0         
                                                                 
 time_distributed_1 (TimeDis  (None, 500, 1)           33        
 tributed)                                                       
                                                                 
=================================================================
Total params: 218,785
Trainable params: 218,785
Non-trainable params: 0
_________________________________________________________________