lstm/gru 的输入数据准备

Input data preparation for lstm/gru

我在理解如何转换我的数据以馈送到网络时遇到问题(我认为 lstm 网络有帮助,因为我的数据主要是时间序列类型并且也有一些时间信息所以..)。

这是数据格式 前 6 列代表一秒钟的数据 (larger_corr, shorter_corr,noiseratio,x,y,z) 然后是相应的输出特征,然后是下一秒数据。

但是为了准备训练数据,我怎样才能发送 6 列数据,然后再发送 6 列 columns.All 这些列的长度为 40。

不知道我表达的够不够清楚

如果您需要任何其他信息,请告诉我。

您可以尝试按如下方式准备您的数据,但请注意,我只使用了 12 列以确保可读性:

import pandas as pd
import numpy as np
import tensorflow as tf
import tabulate
np.random.seed(0)

df = pd.DataFrame({
    'larger_corr' : np.random.randn(25),
    'shorter_corr' : np.random.randn(25),
    'noiseratio' : np.random.randn(25),
    'x' : np.random.randn(25),
    'y' : np.random.randn(25),
    'z' : np.random.randn(25),
    'output' : np.random.randint(0,2,25),
    'larger_corr.1' : np.random.randn(25),
    'shorter_corr.1' : np.random.randn(25),
    'noiseratio.1' : np.random.randn(25),
    'x.1' : np.random.randn(25),
    'y.1' : np.random.randn(25),
    'z.1' : np.random.randn(25),
    'output.1' : np.random.randint(0,2,25)
})

print(df.to_markdown())
y1, y2 = df.pop('output').to_numpy(), df.pop('output.1').to_numpy()
data = df.to_numpy()
x1, x2 = np.array_split(data, 2, axis=1)
x1 = np.expand_dims(x1, axis=1) # add timestep dimension
x2 = np.expand_dims(x2, axis=1) # add timestep dimension
X = np.concatenate([x1, x2])
Y = np.concatenate([y1, y1])
print('Shape of X -->', X.shape, 'Shape of labels -->', Y.shape)
|    |   larger_corr |   shorter_corr |   noiseratio |          x |         y |          z |   output |   larger_corr.1 |   shorter_corr.1 |   noiseratio.1 |        x.1 |        y.1 |         z.1 |   output.1 |
|---:|--------------:|---------------:|-------------:|-----------:|----------:|-----------:|---------:|----------------:|-----------------:|---------------:|-----------:|-----------:|------------:|-----------:|
|  0 |      1.76405  |     -1.45437   |   -0.895467  | -0.68481   |  1.88315  | -0.149635  |        1 |       0.438871  |       -0.244179  |     -0.891895  | -0.617166  |  1.14367   | -0.936916   |          0 |
|  1 |      0.400157 |      0.0457585 |    0.386902  | -0.870797  | -1.34776  | -0.435154  |        1 |       0.63826   |        0.475261  |      0.570081  | -1.77556   | -0.188056  | -1.97935    |          0 |
|  2 |      0.978738 |     -0.187184  |   -0.510805  | -0.57885   | -1.27048  |  1.84926   |        0 |       2.01584   |       -0.714216  |      2.66323   | -1.11821   |  1.24678   |  0.445384   |          0 |
|  3 |      2.24089  |      1.53278   |   -1.18063   | -0.311553  |  0.969397 |  0.672295  |        0 |      -0.243653  |       -1.18694   |      0.410289  | -1.60639   | -0.253884  | -0.195333   |          1 |
|  4 |      1.86756  |      1.46936   |   -0.0281822 |  0.0561653 | -1.17312  |  0.407462  |        1 |       1.53384   |        0.608891  |      0.485652  | -0.814676  | -0.870176  | -0.202716   |          1 |
|  5 |     -0.977278 |      0.154947  |    0.428332  | -1.16515   |  1.94362  | -0.769916  |        1 |       0.76475   |        0.504223  |      1.31153   |  0.321281  |  0.0196537 |  0.219389   |          0 |
|  6 |      0.950088 |      0.378163  |    0.0665172 |  0.900826  | -0.413619 |  0.539249  |        0 |      -2.45668   |       -0.513996  |     -0.235649  | -0.12393   | -1.11437   | -1.03016    |          0 |
|  7 |     -0.151357 |     -0.887786  |    0.302472  |  0.465662  | -0.747455 | -0.674333  |        1 |      -1.70365   |        0.818475  |     -1.48018   |  0.0221213 |  0.607842  | -0.929744   |          0 |
|  8 |     -0.103219 |     -1.9808    |   -0.634322  | -1.53624   |  1.92294  |  0.0318306 |        1 |       0.420153  |        1.1566    |     -0.0214848 | -0.321287  |  0.457237  | -2.55857    |          1 |
|  9 |      0.410599 |     -0.347912  |   -0.362741  |  1.48825   |  1.48051  | -0.635846  |        1 |      -0.298149  |       -0.803689  |      1.05279   |  0.692618  |  0.875539  |  1.6495     |          0 |
| 10 |      0.144044 |      0.156349  |   -0.67246   |  1.89589   |  1.86756  |  0.676433  |        1 |       0.263602  |       -0.551562  |     -0.117402  | -0.353524  |  0.346481  |  0.611738   |          0 |
| 11 |      1.45427  |      1.23029   |   -0.359553  |  1.17878   |  0.906045 |  0.576591  |        1 |       0.731266  |       -0.332414  |      1.82851   |  0.81229   | -0.454874  | -1.05194    |          1 |
| 12 |      0.761038 |      1.20238   |   -0.813146  | -0.179925  | -0.861226 | -0.208299  |        1 |       0.22807   |        1.84452   |     -0.0166771 | -1.14179   |  0.198095  | -0.754946   |          0 |
| 13 |      0.121675 |     -0.387327  |   -1.72628   | -1.07075   |  1.91006  |  0.396007  |        0 |      -2.02852   |       -0.422776  |      1.87011   | -0.287549  |  0.391408  |  0.623188   |          1 |
| 14 |      0.443863 |     -0.302303  |    0.177426  |  1.05445   | -0.268003 | -1.09306   |        0 |       0.96619   |        0.487659  |     -0.380307  |  1.31554   | -3.17786   |  0.00470758 |          0 |
| 15 |      0.333674 |     -1.04855   |   -0.401781  | -0.403177  |  0.802456 | -1.49126   |        1 |      -0.186922  |       -0.375828  |      0.428698  |  0.685781  | -0.956575  | -0.899891   |          0 |
| 16 |      1.49408  |     -1.42002   |   -1.6302    |  1.22245   |  0.947252 |  0.439392  |        0 |      -0.472325  |        0.227851  |      0.361896  |  0.524599  | -0.0312749 |  0.129242   |          1 |
| 17 |     -0.205158 |     -1.70627   |    0.462782  |  0.208275  | -0.15501  |  0.166673  |        1 |       1.93666   |        0.703789  |      0.467568  | -0.793387  |  1.03272   |  0.979693   |          1 |
| 18 |      0.313068 |      1.95078   |   -0.907298  |  0.976639  |  0.614079 |  0.635031  |        0 |       1.47734   |       -0.7978    |     -1.51803   | -0.237881  | -1.21562   |  0.328375   |          0 |
| 19 |     -0.854096 |     -0.509652  |    0.0519454 |  0.356366  |  0.922207 |  2.38314   |        0 |      -0.0848901 |       -0.6759    |     -1.89304   |  0.569498  | -0.318678  |  0.487074   |          0 |
| 20 |     -2.55299  |     -0.438074  |    0.729091  |  0.706573  |  0.376426 |  0.944479  |        1 |       0.427697  |       -0.922546  |     -0.785087  | -1.51061   |  1.49513   |  0.144842   |          1 |
| 21 |      0.653619 |     -1.2528    |    0.128983  |  0.0105    | -1.0994   | -0.912822  |        1 |      -0.30428   |       -0.448586  |     -1.60529   | -1.56505   | -0.130251  | -0.0856099  |          1 |
| 22 |      0.864436 |      0.77749   |    1.1394    |  1.78587   |  0.298238 |  1.11702   |        1 |       0.204625  |        0.181979  |      1.43184   | -3.05123   | -1.20289   |  0.71054    |          1 |
| 23 |     -0.742165 |     -1.6139    |   -1.23483   |  0.126912  |  1.32639  | -1.31591   |        1 |      -0.0833382 |       -0.220084  |     -1.94219   |  1.55966   |  0.199565  |  0.93096    |          0 |
| 24 |      2.26975  |     -0.21274   |    0.402342  |  0.401989  | -0.694568 | -0.461585  |        1 |       1.82893   |        0.0249562 |      1.13995   | -2.63101   |  0.393166  |  0.875074   |          0 |
Shape of X --> (50, 1, 6) Shape of labels --> (50,)

预处理数据后,您可以创建这样的 LSTM 模型,其中维度 timesteps 代表 1 秒:

timesteps, features = X.shape[1], X.shape[2]
input = tf.keras.layers.Input(shape=(timesteps, features))
x = tf.keras.layers.LSTM(32, return_sequences=False)(input)
output = tf.keras.layers.Dense(1, activation='sigmoid')(x)
model = tf.keras.Model(input, output)
model.compile(optimizer='adam', loss=tf.keras.losses.BinaryCrossentropy())
print(model.summary())
model.fit(X, Y, batch_size=10, epochs=5)
Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_16 (InputLayer)       [(None, 1, 6)]            0         
                                                                 
 lstm_1 (LSTM)               (None, 32)                4992      
                                                                 
 dense_21 (Dense)            (None, 1)                 33        
                                                                 
=================================================================
Total params: 5,025
Trainable params: 5,025
Non-trainable params: 0
_________________________________________________________________
None
Epoch 1/5
5/5 [==============================] - 2s 4ms/step - loss: 0.6914
Epoch 2/5
5/5 [==============================] - 0s 3ms/step - loss: 0.6852
Epoch 3/5
5/5 [==============================] - 0s 3ms/step - loss: 0.6806
Epoch 4/5
5/5 [==============================] - 0s 4ms/step - loss: 0.6758
Epoch 5/5
5/5 [==============================] - 0s 4ms/step - loss: 0.6705
<keras.callbacks.History at 0x7f90ca6c6d90>

您还可以在使用 MinMaxScaler or StandardScaler 将数据输入模型之前缩放/规范化数据,但我会把它留给您。