LSTM-based architecture for EEG signal classification based on "Channel LSTM"

I have a multi-class classification problem in Python 3.6 using Keras and TensorFlow. I already have a well-performing, high-accuracy implementation based on the "stacked LSTM layer (a)" described in this paper: Deep Learning Human Mind for Automated Visual Classification.

Something like this:

model.add(LSTM(256, input_shape=(32, 15360), return_sequences=True))
model.add(LSTM(128, return_sequences=True))
model.add(LSTM(64, return_sequences=False))

model.add(Dense(6, activation='softmax'))

Here 32 is the number of EEG channels and 15360 is the signal length of a 96-second recording sampled at 160 Hz.

I want to implement the "Channel LSTM and Common LSTM (b)" strategy mentioned in the same paper, but I don't know how to build my model with this new strategy.

Please help me. Thanks.

First, there is a problem in your encoder implementation with Common LSTM: by default, the LSTM layer of Keras expects input of shape (batch, timesteps, channels), so if you set input_shape=(32, 15360), the model will read it as timesteps=32 and channels=15360, which is the opposite of what you want.
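To see how Keras interprets the input axes, here is a minimal sketch (the sizes here, 2/32/5 and 7 units, are arbitrary values chosen just for illustration):

```python
import numpy as np
import tensorflow as tf

# Keras LSTM expects (batch, timesteps, features):
x = np.random.rand(2, 32, 5).astype("float32")  # batch=2, timesteps=32, features=5
out = tf.keras.layers.LSTM(7, return_sequences=True)(x)
print(out.shape)  # (2, 32, 7): timesteps are preserved, features become units
```

Whatever you pass as the last axis is treated as the per-timestep feature (channel) dimension.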

Because the first layer of the encoder using Common LSTM is described as:

At each time step t, the first layer takes the input s(·, t) (in this sense, "common" means that all EEG channels are initially fed into the same LSTM layer)

the correct implementation of the encoder using Common LSTM is:

import tensorflow as tf
from tensorflow.keras import layers, models

timesteps = 15360
channels_num = 32

model = models.Sequential()
model.add(layers.LSTM(256, input_shape=(timesteps, channels_num), return_sequences=True))
model.add(layers.LSTM(128, return_sequences=True))
model.add(layers.LSTM(64, return_sequences=False))
model.add(layers.Dense(6, activation='softmax'))

model.summary()

which outputs (side note: if you call summary() on your original implementation, you will see that its Total params count is much larger):

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
lstm (LSTM)                  (None, 15360, 256)        295936
_________________________________________________________________
lstm_1 (LSTM)                (None, 15360, 128)        197120
_________________________________________________________________
lstm_2 (LSTM)                (None, 64)                49408
_________________________________________________________________
dense (Dense)                (None, 6)                 390
=================================================================
Total params: 542,854
Trainable params: 542,854
Non-trainable params: 0
_________________________________________________________________
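As a sanity check, the parameter counts in the summary above can be reproduced by hand. A Keras LSTM with input dimension d and u units has 4 * (d + u + 1) * u weights (four gates, each with an input kernel, recurrent kernel, and bias); lstm_params below is a small helper written here just for this check:

```python
def lstm_params(input_dim, units):
    """Weights of a Keras LSTM layer: 4 gates x (kernel + recurrent kernel + bias)."""
    return 4 * (input_dim + units + 1) * units

total = (lstm_params(32, 256)     # lstm:   295,936
         + lstm_params(256, 128)  # lstm_1: 197,120
         + lstm_params(128, 64)   # lstm_2:  49,408
         + 64 * 6 + 6)            # dense:      390
print(total)  # 542854, matching "Total params" in the summary
```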

Second, because the encoder using Channel LSTM and Common LSTM is described as:

The first encoding layer consists of several LSTMs, each connected to only one input channel: for example, the first LSTM processes input data s(1, ·), the second LSTM processes s(2, ·), and so on. In this way, the output of each "channel LSTM" is a summary of a single channel's data. The second encoding layer then performs inter-channel analysis, by receiving as input the concatenated output vectors of all channel LSTMs. As above, the output of the deepest LSTM at the last time step is used as the encoder's output vector.

Since each LSTM in the first layer processes only one channel, we need as many LSTMs in the first layer as there are channels. The code below shows how to build an encoder using Channel LSTM and Common LSTM:

import tensorflow as tf
from tensorflow.keras import layers, models

timesteps = 15360
channels_num = 32

first_layer_inputs = []
second_layer_inputs = []
for i in range(channels_num):
    l_input = layers.Input(shape=(timesteps, 1))
    first_layer_inputs.append(l_input)
    l_output = layers.LSTM(1, return_sequences=True)(l_input)
    second_layer_inputs.append(l_output)

x = layers.Concatenate()(second_layer_inputs)
x = layers.LSTM(128, return_sequences=True)(x)
x = layers.LSTM(64, return_sequences=False)(x)
outputs = layers.Dense(6, activation='softmax')(x)

model = models.Model(inputs=first_layer_inputs, outputs=outputs)

model.summary()

Output:

Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_3 (InputLayer)            [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_4 (InputLayer)            [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_5 (InputLayer)            [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_6 (InputLayer)            [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_7 (InputLayer)            [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_8 (InputLayer)            [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_9 (InputLayer)            [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_10 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_11 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_12 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_13 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_14 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_15 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_16 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_17 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_18 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_19 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_20 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_21 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_22 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_23 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_24 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_25 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_26 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_27 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_28 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_29 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_30 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_31 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
input_32 (InputLayer)           [(None, 15360, 1)]   0
__________________________________________________________________________________________________
lstm (LSTM)                     (None, 15360, 1)     12          input_1[0][0]
__________________________________________________________________________________________________
lstm_1 (LSTM)                   (None, 15360, 1)     12          input_2[0][0]
__________________________________________________________________________________________________
lstm_2 (LSTM)                   (None, 15360, 1)     12          input_3[0][0]
__________________________________________________________________________________________________
lstm_3 (LSTM)                   (None, 15360, 1)     12          input_4[0][0]
__________________________________________________________________________________________________
lstm_4 (LSTM)                   (None, 15360, 1)     12          input_5[0][0]
__________________________________________________________________________________________________
lstm_5 (LSTM)                   (None, 15360, 1)     12          input_6[0][0]
__________________________________________________________________________________________________
lstm_6 (LSTM)                   (None, 15360, 1)     12          input_7[0][0]
__________________________________________________________________________________________________
lstm_7 (LSTM)                   (None, 15360, 1)     12          input_8[0][0]
__________________________________________________________________________________________________
lstm_8 (LSTM)                   (None, 15360, 1)     12          input_9[0][0]
__________________________________________________________________________________________________
lstm_9 (LSTM)                   (None, 15360, 1)     12          input_10[0][0]
__________________________________________________________________________________________________
lstm_10 (LSTM)                  (None, 15360, 1)     12          input_11[0][0]
__________________________________________________________________________________________________
lstm_11 (LSTM)                  (None, 15360, 1)     12          input_12[0][0]
__________________________________________________________________________________________________
lstm_12 (LSTM)                  (None, 15360, 1)     12          input_13[0][0]
__________________________________________________________________________________________________
lstm_13 (LSTM)                  (None, 15360, 1)     12          input_14[0][0]
__________________________________________________________________________________________________
lstm_14 (LSTM)                  (None, 15360, 1)     12          input_15[0][0]
__________________________________________________________________________________________________
lstm_15 (LSTM)                  (None, 15360, 1)     12          input_16[0][0]
__________________________________________________________________________________________________
lstm_16 (LSTM)                  (None, 15360, 1)     12          input_17[0][0]
__________________________________________________________________________________________________
lstm_17 (LSTM)                  (None, 15360, 1)     12          input_18[0][0]
__________________________________________________________________________________________________
lstm_18 (LSTM)                  (None, 15360, 1)     12          input_19[0][0]
__________________________________________________________________________________________________
lstm_19 (LSTM)                  (None, 15360, 1)     12          input_20[0][0]
__________________________________________________________________________________________________
lstm_20 (LSTM)                  (None, 15360, 1)     12          input_21[0][0]
__________________________________________________________________________________________________
lstm_21 (LSTM)                  (None, 15360, 1)     12          input_22[0][0]
__________________________________________________________________________________________________
lstm_22 (LSTM)                  (None, 15360, 1)     12          input_23[0][0]
__________________________________________________________________________________________________
lstm_23 (LSTM)                  (None, 15360, 1)     12          input_24[0][0]
__________________________________________________________________________________________________
lstm_24 (LSTM)                  (None, 15360, 1)     12          input_25[0][0]
__________________________________________________________________________________________________
lstm_25 (LSTM)                  (None, 15360, 1)     12          input_26[0][0]
__________________________________________________________________________________________________
lstm_26 (LSTM)                  (None, 15360, 1)     12          input_27[0][0]
__________________________________________________________________________________________________
lstm_27 (LSTM)                  (None, 15360, 1)     12          input_28[0][0]
__________________________________________________________________________________________________
lstm_28 (LSTM)                  (None, 15360, 1)     12          input_29[0][0]
__________________________________________________________________________________________________
lstm_29 (LSTM)                  (None, 15360, 1)     12          input_30[0][0]
__________________________________________________________________________________________________
lstm_30 (LSTM)                  (None, 15360, 1)     12          input_31[0][0]
__________________________________________________________________________________________________
lstm_31 (LSTM)                  (None, 15360, 1)     12          input_32[0][0]
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 15360, 32)    0           lstm[0][0]
                                                                 lstm_1[0][0]
                                                                 lstm_2[0][0]
                                                                 lstm_3[0][0]
                                                                 lstm_4[0][0]
                                                                 lstm_5[0][0]
                                                                 lstm_6[0][0]
                                                                 lstm_7[0][0]
                                                                 lstm_8[0][0]
                                                                 lstm_9[0][0]
                                                                 lstm_10[0][0]
                                                                 lstm_11[0][0]
                                                                 lstm_12[0][0]
                                                                 lstm_13[0][0]
                                                                 lstm_14[0][0]
                                                                 lstm_15[0][0]
                                                                 lstm_16[0][0]
                                                                 lstm_17[0][0]
                                                                 lstm_18[0][0]
                                                                 lstm_19[0][0]
                                                                 lstm_20[0][0]
                                                                 lstm_21[0][0]
                                                                 lstm_22[0][0]
                                                                 lstm_23[0][0]
                                                                 lstm_24[0][0]
                                                                 lstm_25[0][0]
                                                                 lstm_26[0][0]
                                                                 lstm_27[0][0]
                                                                 lstm_28[0][0]
                                                                 lstm_29[0][0]
                                                                 lstm_30[0][0]
                                                                 lstm_31[0][0]
__________________________________________________________________________________________________
lstm_32 (LSTM)                  (None, 15360, 128)   82432       concatenate[0][0]
__________________________________________________________________________________________________
lstm_33 (LSTM)                  (None, 64)           49408       lstm_32[0][0]
__________________________________________________________________________________________________
dense (Dense)                   (None, 6)            390         lstm_33[0][0]
==================================================================================================
Total params: 132,614
Trainable params: 132,614
Non-trainable params: 0
__________________________________________________________________________________________________
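The same hand count with the 4 * (d + u + 1) * u formula works here too, and shows why this encoder is so much smaller than the Common-LSTM-only one: each of the 32 per-channel LSTMs has only 12 parameters:

```python
def lstm_params(input_dim, units):
    """Weights of a Keras LSTM layer: 4 gates x (kernel + recurrent kernel + bias)."""
    return 4 * (input_dim + units + 1) * units

total = (32 * lstm_params(1, 1)   # 32 channel LSTMs, 12 params each
         + lstm_params(32, 128)   # lstm_32: 82,432
         + lstm_params(128, 64)   # lstm_33: 49,408
         + 64 * 6 + 6)            # dense:      390
print(total)  # 132614, matching "Total params" in the summary
```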

Now, because this model expects a list of per-channel inputs, i.e. data of shape (channel, batch, timesteps, 1), we have to reorder the axes of the dataset before feeding it to the model. The following example code shows how to reorder the axes from (batch, timesteps, channel) to (channel, batch, timesteps, 1):

import numpy as np

batch_size = 64
timesteps = 15360
channels_num = 32

x = np.random.rand(batch_size, timesteps, channels_num)
print(x.shape)
x = np.moveaxis(x, -1, 0)[..., np.newaxis]
print(x.shape)
x = [i for i in x]
print(x[0].shape)

Output:

(64, 15360, 32)
(32, 64, 15360, 1)
(64, 15360, 1)
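A quick numpy check (with small, arbitrary sizes instead of the real ones) confirms that after this reordering, element i of the list is exactly channel i of the original array, which is what each channel LSTM should receive:

```python
import numpy as np

batch_size, timesteps, channels_num = 4, 6, 3  # small sizes, for illustration only
orig = np.random.rand(batch_size, timesteps, channels_num)

x = np.moveaxis(orig, -1, 0)[..., np.newaxis]
x_list = [i for i in x]  # one (batch, timesteps, 1) array per channel

# Each list element holds exactly one channel's data:
for c in range(channels_num):
    assert np.array_equal(x_list[c][..., 0], orig[..., c])
print(len(x_list), x_list[0].shape)  # 3 (4, 6, 1)
```

This list can then be passed directly as the model's inputs, e.g. model.fit(x_list, y).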