仅在 Keras 中对某些输入进行 BatchNormalization
BatchNormalization for some inputs only in Keras
我的 LSTM 网络有 5 个输入。第一个输入的典型值是从 1000 到 3000。其余输入的值是从 -1 到 1。
我想插入 BatchNormalization 作为第一层。但是输入 2-5 已经在 -1 和 1 之间,并且第一个输入比第一个输入大得多。那就是我只想对第一个输入应用批量归一化,并按原样保留输入 2-5。然后应该将第一个(归一化)输入和 2-5 个输入传递给 LSTM 层。
+----+ +---+
1 -->| BN |-->| |
+----+ | L |
2 ----------->| S |
3 ----------->| T |
4 ----------->| M |
5 ----------->| |
+---+
如何在 Keras 中完成?
我认为我可以为第一个输入创建带有 BatchNormalization 裸层的模型,然后将它与其余层连接起来。但我不确定,也不知道具体怎么做。
考虑到您的训练数据形状为 (batch,timeSteps,5)
,也许您应该像这样简单地更改输入:
maxVal = abs(X_train[:,:,0].max())
minVal = abs(X_train[:,:,0].min())
maxVal = max(maxVal,minVal)
X_train[:,:,0] = X_train[:,:,0] / maxVal
试试下面的定义:
from keras.layers.merge import concatenate
input_tensor = Input(shape=(timesteps, 5))
# now let's split tensors
split_1 = Lambda(lambda x: x[:, :, :1])(input_tensor)
split_2 = Lambda(lambda x: x[:, :, 1:])(input_tensor)
split_1 = BatchNormalization()(split_1)
# now let's concatenate them again
follow = concatenate([split_1, split_2])
但正如 Daniel 在他的评论中提到的那样 - 最好对数据进行规范化以处理此类不一致 - 使用 BatchNormalization
可能会导致性能下降。
我的 LSTM 网络有 5 个输入。第一个输入的典型值是从 1000 到 3000。其余输入的值是从 -1 到 1。
我想插入 BatchNormalization 作为第一层。但是输入 2-5 已经在 -1 和 1 之间,并且第一个输入比第一个输入大得多。那就是我只想对第一个输入应用批量归一化,并按原样保留输入 2-5。然后应该将第一个(归一化)输入和 2-5 个输入传递给 LSTM 层。
+----+ +---+
1 -->| BN |-->| |
+----+ | L |
2 ----------->| S |
3 ----------->| T |
4 ----------->| M |
5 ----------->| |
+---+
如何在 Keras 中完成?
我认为我可以为第一个输入创建带有 BatchNormalization 裸层的模型,然后将它与其余层连接起来。但我不确定,也不知道具体怎么做。
考虑到您的训练数据形状为 (batch,timeSteps,5)
,也许您应该像这样简单地更改输入:
maxVal = abs(X_train[:,:,0].max())
minVal = abs(X_train[:,:,0].min())
maxVal = max(maxVal,minVal)
X_train[:,:,0] = X_train[:,:,0] / maxVal
试试下面的定义:
from keras.layers.merge import concatenate
input_tensor = Input(shape=(timesteps, 5))
# now let's split tensors
split_1 = Lambda(lambda x: x[:, :, :1])(input_tensor)
split_2 = Lambda(lambda x: x[:, :, 1:])(input_tensor)
split_1 = BatchNormalization()(split_1)
# now let's concatenate them again
follow = concatenate([split_1, split_2])
但正如 Daniel 在他的评论中提到的那样 - 最好对数据进行规范化以处理此类不一致 - 使用 BatchNormalization
可能会导致性能下降。