Proper input and output shape of a Keras Sequential model
I am trying to run a Keras Sequential model but cannot get the shapes right for training.
I reshaped x and y to:
x = x.reshape(len(x), 500)
y = y.reshape(len(y), 500)
Currently, both the input shape and the output shape are:
(9766, 500)
(9766, 500)
The dataset consists of 9766 inputs and 9766 outputs. Each input is an array of 500 values, and each output is also an array of 500 values.
So this is a single input array:
[0.99479668 0.99477965 0.99484778 0.99489887 0.99483926 0.99451565
0.99458378 0.99457526 0.99453268 0.99468597 0.99466042 0.99449862
0.99453268 0.99454971 0.99463487 0.99461784 0.99451565 0.99463487
0.99467745 0.99502661 0.99480519 0.99493294 0.99493294 0.99522248
0.99526506 0.99528209 0.99527358 0.99515435 0.99529913 0.99488184
0.99508623 0.99512881 0.99522248 0.99497552 0.9954439 0.99554609
0.99581861 0.99573345 0.9957079 0.99626144 0.99626144 0.99592932
0.99558867 0.99541835 0.99524803 0.99586119 0.99601448 0.99588674
0.99584416 0.99559719 0.995495 0.99520545 0.99552055 0.99510326
0.9951799 0.99560571 0.99561422 0.99541835 0.99586119 0.995759
0.9957079 0.99583564 0.9959208 0.99578454 0.99604854 0.99612519
0.99609112 0.99630402 0.9961337 0.99672983 0.99655099 0.99643176
0.99643176 0.99648286 0.99649138 0.99645731 0.99670428 0.99654247
0.99647435 0.99607409 0.99589525 0.99600596 0.99596338 0.99621035
0.99633809 0.99632106 0.99583564 0.99581009 0.99574196 0.9959719
0.99557164 0.99567383 0.99572493 0.9958697 0.99568235 0.9959208
0.99598893 0.99620183 0.99611667 0.99620183 0.9959719 0.9957079
0.99612519 0.99558867 0.99569938 0.99518842 0.99553758 0.99552055
0.99576751 0.99577603 0.99583564 0.99602299 0.99630402 0.99637215
0.99701937 0.99701086 0.99731744 0.99700234 0.99696828 0.99668725
0.99703641 0.99725782 0.99684054 0.99605706 0.99608261 0.99581861
0.9958697 0.99583564 0.99566532 0.99585267 0.99566532 0.99604003
0.99540984 0.99473707 0.995231 0.99441346 0.9942261 0.99397914
0.99367256 0.99409836 0.99415797 0.99420907 0.99398765 0.99356185
0.99382585 0.99428571 0.9945412 0.99444752 0.99436236 0.99404726
0.9938003 0.99424313 0.99483074 0.99474558 0.99457526 0.99457526
0.99465191 0.99466042 0.99467745 0.99448158 0.99454971 0.99479668
0.994703 0.99455823 0.99472855 0.99507771 0.99529913 0.99515435
0.99525655 0.99621886 0.99586119 0.99576751 0.9962359 0.99614222
0.99723228 0.99685757 0.99680647 0.99689163 0.99644028 0.99701937
0.99675538 0.99637215 0.99614222 0.99628699 0.9964488 0.99641473
0.99652544 0.99652544 0.99664467 0.99698531 0.99712157 0.99703641
0.99799872 0.99859485 0.99876517 0.99950607 0.99902065 0.99891846
0.99804982 0.99839898 0.99857782 0.99850117 0.99891846 0.99912284
0.99919097 0.99919949 0.99896956 0.99896104 0.99877369 0.99898659
0.99918246 0.99890994 0.9990462 0.99895252 0.99885033 0.99871407
0.99871407 0.99871407 0.99864594 0.99854375 0.9983564 0.9985693
0.99870556 0.99868001 0.9987822 0.99877369 0.99900362 0.99882478
0.99896956 0.99885885 0.99880775 0.99890994 0.99906323 0.99908026
0.9990462 0.99921652 0.99920801 0.99936129 0.99937833 0.99943794
0.99935278 0.99943794 0.99967639 0.99956568 0.99960826 0.99962529
0.99942942 0.99940387 0.9992591 0.99908878 0.99912284 0.99913988
0.99905472 0.99914839 0.99913136 0.99933575 0.99935278 0.99929317
0.99931871 0.99905472 0.99965084 0.99995742 1. 0.99962529
0.999472 0.99939536 0.99932723 0.99929317 0.99931871 0.99931871
0.99950607 0.99953162 0.99942942 0.99919097 0.99902917 0.99913988
0.99915691 0.9990462 0.9990973 0.99923355 0.99940387 0.99954865
0.99958271 0.99940387 0.99943794 0.99928465 0.9990973 0.99905472
0.99915691 0.99921652 0.99913988 0.99913136 0.99912284 0.9992591
0.99916542 0.99917394 0.99918246 0.99906323 0.99905472 0.99907175
0.99901214 0.9990462 0.99913988 0.9990462 0.9990462 0.99880775
0.99890994 0.99868852 0.99868852 0.99889291 0.99896956 0.99886736
0.99932723 0.99943794 0.99932723 0.99931871 0.99931871 0.99921652
0.99874814 0.99871407 0.99915691 0.99969342 0.99962529 0.99916542
0.99902917 0.99887588 0.99919097 0.99943794 0.99847562 0.9988333
0.99905472 0.99913988 0.99931871 0.99936129 0.99893549 0.99869704
0.99842453 0.99868001 0.99868852 0.9987822 0.9987311 0.99871407
0.99860336 0.99826272 0.99805834 0.99785395 0.99792208 0.99804982
0.99797317 0.99797317 0.99778582 0.99749627 0.99751331 0.99758143
0.99732595 0.99741111 0.99699383 0.99733447 0.99728337 0.99686608
0.99714712 0.9973515 0.99753885 0.99753034 0.99762402 0.99774324
0.99781989 0.99765808 0.99739408 0.9974026 0.99723228 0.99737705
0.99728337 0.99728337 0.99736002 0.99726634 0.99732595 0.99721524
0.99728337 0.99701937 0.99715563 0.99715563 0.99744518 0.99753034
0.99747073 0.99765808 0.9978284 0.99726634 0.99724931 0.99776879
0.99746221 0.9976666 0.9976666 0.99744518 0.99734298 0.99833085
0.99866298 0.99800724 0.99714712 0.99648286 0.99588674 0.99598041
0.99563125 0.99595486 0.99626144 0.99601448 0.99456674 0.9947541
0.99499255 0.99483926 0.9950181 0.99497552 0.99484778 0.99424313
0.99416649 0.99416649 0.9942772 0.99288908 0.99266766 0.99293166
0.99248031 0.99312753 0.99269321 0.99307643 0.99286353 0.99319566
0.99346817 0.99337449 0.99322972 0.99302534 0.99322121 0.99307643
0.99295721 0.99344262 0.99262508 0.99259953 0.99246327 0.99254844
0.99265063 0.99288908 0.99288908 0.9930594 0.9933234 0.99340004
0.99320417 0.99331488 0.99319566 0.99335746 0.99322121 0.99271876
0.99271024 0.99270172 0.99259102 0.99308495 0.99331488 0.9930083
0.99285501 0.99289759 0.99276134 0.99259102 0.99266766 0.99221631
0.99216521 0.99225889 0.99227592 0.99196934 0.99162018 0.99147541
0.99134767 0.99159463 0.99152651 0.99166276 0.99169683 0.99168831
0.99175644 0.99178199 0.99161167 0.99165425 0.99170534 0.9915776
0.9915776 0.99144135 0.99169683 0.99170534 0.99144986 0.99170534
0.99187567 0.99192676 0.99183308 0.99177347 0.99173941 0.99176496
0.99170534 0.9917905 0.99178199 0.99144986 0.99147541 0.99142431
0.99149244 0.99139877]
And this is a single output array:
[0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.99449862
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0.99731744 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.99356185
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 1. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.99686608
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0.99866298 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0.99134767 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0.
0. 0. ]
This is the model I am trying to train the data with (most likely a poor architecture):
model = Sequential()
model.add(LSTM(128, input_shape=(x.shape[1:]), return_sequences=True))
model.add(Dropout(0.2))
model.add(LSTM(128, input_shape=(x.shape[1:])))
model.add(Dropout(0.1))
model.add(LSTM(32, input_shape=(x.shape[1:]), activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(1, activation='sigmoid'))
opt = tf.keras.optimizers.Adam(lr=0.001, decay=1e-3)
model.compile(loss='mse',optimizer=opt, metrics=['accuracy'])
model.fit(x,y,epochs=20,validation_split=0.20)
The way I want the model to train is to look at an input and produce an array of 500 values, like the output array shown above.
But no matter what shapes I try, I get errors like:
ValueError: Input 0 of layer "lstm" is incompatible with the layer: expected ndim=3, found ndim=2. Full shape received: (None, 500)
What are the correct shapes here, and what am I doing wrong in the model architecture?
Update 1:
I also tried reshaping x and y to:
(9766, 1, 500)
(9766, 1, 500)
Still no luck.
An LSTM layer expects its input in the shape [batch, timesteps, features]. So with the shape (9766, 1, 500) you have a single timestep with 500 features. If you actually have 500 timesteps, the shape should be (9766, 500, 1).
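As a quick sketch of that reshape (assuming x and y are the (9766, 500) NumPy arrays from the question):
# Each of the 500 values becomes one timestep with a single feature.
x = x.reshape(len(x), 500, 1)   # (9766, 500) -> (9766, 500, 1)
y = y.reshape(len(y), 500, 1)   # (9766, 500) -> (9766, 500, 1)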
Here is an example architecture:
import tensorflow as tf

# Dummy data in the expected shape: 9766 samples, 500 timesteps, 1 feature per timestep.
x = tf.random.uniform((9766, 500, 1))
y = tf.random.uniform((9766, 500, 1))

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.InputLayer(input_shape=x.shape[1:]))
model.add(tf.keras.layers.LSTM(128, return_sequences=True, activation='relu'))
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.LSTM(128, activation='relu', return_sequences=True))
model.add(tf.keras.layers.Dropout(0.1))
model.add(tf.keras.layers.LSTM(32, activation='relu', return_sequences=True))
model.add(tf.keras.layers.Dropout(0.2))
# You can also drop the TimeDistributed wrapper if that gives better results.
# I assume you need the output values between 0.0 and 1.0, hence the sigmoid.
model.add(tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(1, activation='sigmoid')))
# Be careful about your chosen metric: 'accuracy' is not very informative for an MSE regression target.
model.compile(loss='mse',
              optimizer=tf.keras.optimizers.Adam(learning_rate=0.001, decay=1e-3),
              metrics=['accuracy'])
model.summary()
If you check the model summary, you will see that the input and output shapes are what you expect:
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
lstm_15 (LSTM)               (None, 500, 128)          66560
dropout_13 (Dropout)         (None, 500, 128)          0
lstm_16 (LSTM)               (None, 500, 128)          131584
dropout_14 (Dropout)         (None, 500, 128)          0
lstm_17 (LSTM)               (None, 500, 32)           20608
dropout_15 (Dropout)         (None, 500, 32)           0
time_distributed_1 (TimeDis  (None, 500, 1)            33
tributed)
=================================================================
Total params: 218,785
Trainable params: 218,785
Non-trainable params: 0
_________________________________________________________________
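With your real data reshaped to (9766, 500, 1) (rather than the random placeholders above), training and prediction would then look roughly like this:
model.fit(x, y, epochs=20, validation_split=0.20)
preds = model.predict(x)       # shape (9766, 500, 1)
preds = preds.squeeze(-1)      # back to (9766, 500): one predicted value per timestep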