Keras Seq2Seq 简介
Keras Seq2Seq Introduction
几周前发布了 Keras 对 Seq2Seq 模型的介绍here。我不太理解这段代码的一部分:
decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
decoder_outputs, _, _= decoder_lstm(decoder_inputs,initial_state=encoder_states)
decoder_dense = Dense(num_decoder_tokens, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)
此处定义了decoder_lstm
。是维度latent_dim
的一层。我们使用编码器的状态作为解码器的 initial_state。
我不明白的是为什么在 LSTM 层之后添加一个密集层,为什么它起作用?
由于 return_sequences = True
,解码器应该 return 所有序列,那么在工作之后添加密集层怎么可能?
我想我在这里漏掉了什么。
虽然常见情况使用 2D 数据 (batch,dim)
作为密集层的输入,但在较新版本的 Keras 中,您可以使用 3D 数据 (batch,timesteps,dim)
。
如果您不展平此 3D 数据,您的 Dense 图层的行为就好像它会应用于每个时间步长一样。你会得到像 (batch,timesteps,dense_units)
这样的输出
您可以检查下面的这两个小模型并确认独立于时间步长,两个 Dense 层具有相同数量的参数,表明其参数仅适用于最后一个维度。
from keras.layers import *
from keras.models import Model
import keras.backend as K
#model with time steps
inp = Input((7,12))
out = Dense(5)(inp)
model = Model(inp,out)
model.summary()
#model without time steps
inp2 = Input((12,))
out2 = Dense(5)(inp2)
model2 = Model(inp2,out2)
model2.summary()
结果在两种情况下都会显示 65 (12*5 + 5) 个参数。
几周前发布了 Keras 对 Seq2Seq 模型的介绍here。我不太理解这段代码的一部分:
decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
decoder_outputs, _, _= decoder_lstm(decoder_inputs,initial_state=encoder_states)
decoder_dense = Dense(num_decoder_tokens, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)
此处定义了decoder_lstm
。是维度latent_dim
的一层。我们使用编码器的状态作为解码器的 initial_state。
我不明白的是为什么在 LSTM 层之后添加一个密集层,为什么它起作用?
由于 return_sequences = True
,解码器应该 return 所有序列,那么在工作之后添加密集层怎么可能?
我想我在这里漏掉了什么。
虽然常见情况使用 2D 数据 (batch,dim)
作为密集层的输入,但在较新版本的 Keras 中,您可以使用 3D 数据 (batch,timesteps,dim)
。
如果您不展平此 3D 数据,您的 Dense 图层的行为就好像它会应用于每个时间步长一样。你会得到像 (batch,timesteps,dense_units)
您可以检查下面的这两个小模型并确认独立于时间步长,两个 Dense 层具有相同数量的参数,表明其参数仅适用于最后一个维度。
from keras.layers import *
from keras.models import Model
import keras.backend as K
#model with time steps
inp = Input((7,12))
out = Dense(5)(inp)
model = Model(inp,out)
model.summary()
#model without time steps
inp2 = Input((12,))
out2 = Dense(5)(inp2)
model2 = Model(inp2,out2)
model2.summary()
结果在两种情况下都会显示 65 (12*5 + 5) 个参数。