Imposing grammar rules manually on a Sequence2Sequence Keras model
I have a fairly standard sequence-to-sequence model in Keras that looks like this:
import tensorflow as tf
from tensorflow.keras.layers import Dense, Embedding, Input, LSTM
from tensorflow.keras.models import Model

# create model
encoder_inputs = Input(shape=(None,))
en_x = Embedding(num_encoder_tokens, EMBEDDING_SIZE)(encoder_inputs)
encoder = LSTM(50, return_state=True)
encoder_outputs, state_h, state_c = encoder(en_x)

# We discard `encoder_outputs` and only keep the states.
encoder_states = [state_h, state_c]

# Set up the decoder, using `encoder_states` as initial state.
decoder_inputs = Input(shape=(None,))
dex = Embedding(num_decoder_tokens, EMBEDDING_SIZE)
final_dex = dex(decoder_inputs)
decoder_lstm = LSTM(50, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder_lstm(final_dex, initial_state=encoder_states)
decoder_dense = Dense(num_decoder_tokens, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)

model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.05)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['acc'])
I know this is not a great idea in general, but the data I am translating is not natural language, and I want to impose additional rules on the decoded sequence, e.g. "any word should appear at most once in the decoded sequence". The rule does not apply to the sequence being encoded.
The data I train the model on already obeys this rule, but the model's current output does not. (I am aware this rule makes little sense linguistically.)
Is there a way to do this, and if so, how?
Why not check for repeated words in the decoder and stop decoding as soon as one appears? You would add the rule at the char = target_index_word[word_index]; decoded_sentence += ' ' + char step of the decoding loop; a sketch of that check follows the loop below.
import numpy as np

def get_predicted_sentence(input_seq):
    # Encode the input as state vectors.
    enc_output, enc_h, enc_c = encoder_model.predict(input_seq)

    # Generate an empty target sequence of length 1 and populate it
    # with the start-of-sequence token.
    target_seq = np.zeros((1, 1))
    target_seq[0, 0] = target_word_index['sos']

    # Sampling loop for a batch of sequences
    # (to simplify, here we assume a batch of size 1).
    stop_condition = False
    decoded_sentence = ""
    count = 0
    while not stop_condition:
        count += 1
        if count > 1000:  # safety cap against runaway loops
            print('count exceeded')
            stop_condition = True
        output_words, dec_h, dec_c = decoder_model.predict(
            [target_seq] + [enc_output, enc_h, enc_c])

        # Greedily sample the most likely next token.
        word_index = np.argmax(output_words[0, -1, :])
        char = ""
        if word_index in target_index_word:
            char = target_index_word[word_index]
            decoded_sentence += ' ' + char
            print(decoded_sentence)
        else:
            stop_condition = True

        # Stop at the end-of-sequence token or at the maximum length.
        if char == 'eos' or len(decoded_sentence) >= max_input_len:
            stop_condition = True

        # Update the target sequence (of length 1) with the sampled token.
        target_seq = np.zeros((1, 1))
        target_seq[0, 0] = word_index
        print(target_seq[0, 0])

        # Feed the decoder states back in for the next step.
        enc_h, enc_c = dec_h, dec_c

    return decoded_sentence
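A minimal sketch of that check, assuming the same encoder_model, decoder_model, target_word_index and target_index_word objects as above (target_index_word taken to be an index-to-word dict, consistent with the loop above; the helper name get_predicted_sentence_no_repeats is purely illustrative):

import numpy as np

def get_predicted_sentence_no_repeats(input_seq):
    # Same setup as get_predicted_sentence above.
    enc_output, enc_h, enc_c = encoder_model.predict(input_seq)
    target_seq = np.zeros((1, 1))
    target_seq[0, 0] = target_word_index['sos']

    decoded_words = []
    seen = set()               # words emitted so far
    for _ in range(1000):      # same safety cap as the original loop
        output_words, dec_h, dec_c = decoder_model.predict(
            [target_seq] + [enc_output, enc_h, enc_c])
        word_index = int(np.argmax(output_words[0, -1, :]))
        word = target_index_word.get(word_index, '')

        # The rule: any word may appear at most once in the decoded
        # sequence, so stop decoding as soon as a repeat is sampled.
        if word == '' or word == 'eos' or word in seen:
            break

        decoded_words.append(word)
        seen.add(word)

        # Feed the sampled token and the decoder states back in.
        target_seq = np.zeros((1, 1))
        target_seq[0, 0] = word_index
        enc_h, enc_c = dec_h, dec_c

    return ' '.join(decoded_words)

If cutting the sentence short is too drastic, a softer variant is to zero out output_words[0, -1, i] for every index i whose word is already in seen before taking the argmax, which forces the decoder to continue with a word it has not used yet rather than stopping.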