Imposing grammar rules manually on Sequence2Sequence keras model

I have a fairly standard sequence-to-sequence model in keras, which looks like this:

import tensorflow as tf
from tensorflow.keras.layers import Input, Embedding, LSTM, Dense
from tensorflow.keras.models import Model

# create model
encoder_inputs = Input(shape=(None,))
en_x = Embedding(num_encoder_tokens, EMBEDDING_SIZE)(encoder_inputs)
encoder = LSTM(50, return_state=True)
encoder_outputs, state_h, state_c = encoder(en_x)
# We discard `encoder_outputs` and only keep the states.
encoder_states = [state_h, state_c]

# Set up the decoder, using `encoder_states` as initial state.
decoder_inputs = Input(shape=(None,))
dex = Embedding(num_decoder_tokens, EMBEDDING_SIZE)
final_dex = dex(decoder_inputs)

decoder_lstm = LSTM(50, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder_lstm(final_dex, initial_state=encoder_states)
decoder_dense = Dense(num_decoder_tokens, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)

model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.05)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['acc'])

I know this is not a good idea in general, but the data I am translating is not natural language, and I want to impose further rules on the decoded sequence, e.g. "any word should appear only once in the decoded sequence". The rule does not apply to the sequence being encoded.

The data I train the model on does already obey this rule, but the model's current output does not. (I am aware that this rule makes little sense for natural language.)

Is there a way to do this, and if so, how?

Why not check for repeated words in the decoder and stop decoding as soon as one appears? You can add the rule at the point in the decoder where it does char = target_index_word[word_index] and decoded_sentence += ' ' + char; a concrete version of that check is sketched after the full loop below.
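The loop below relies on two separate inference-time models, encoder_model and decoder_model, which this thread never shows. As a minimal sketch only, assuming the standard Keras inference construction derived from the training model in the question (all names here are assumptions, not the author's code):

from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model

# Sketch only: the inference models are not shown in the thread, so this
# is an assumed construction from the training graph above.
encoder_model = Model(encoder_inputs, [encoder_outputs, state_h, state_c])

decoder_state_input_h = Input(shape=(50,))
decoder_state_input_c = Input(shape=(50,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]

final_dex2 = dex(decoder_inputs)  # reuse the trained embedding layer
decoder_outputs2, state_h2, state_c2 = decoder_lstm(
    final_dex2, initial_state=decoder_states_inputs)
decoder_outputs2 = decoder_dense(decoder_outputs2)

decoder_model = Model(
    [decoder_inputs] + decoder_states_inputs,
    [decoder_outputs2, state_h2, state_c2])

Note that with this plain (no-attention) decoder only the two states matter; the enc_output argument that the loop below also passes to decoder_model.predict would only be needed by an attention variant and would have to be dropped for this sketch.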

import numpy as np

def get_predicted_sentence(input_seq):
    # Assumes encoder_model, decoder_model, target_word_index,
    # target_index_word and max_input_len are defined elsewhere.

    # Encode the input as state vectors.
    enc_output, enc_h, enc_c = encoder_model.predict(input_seq)

    # Generate an empty target sequence of length 1 and seed it
    # with the start-of-sequence token.
    target_seq = np.zeros((1, 1))
    target_seq[0, 0] = target_word_index['sos']

    # Sampling loop for a batch of sequences
    # (to simplify, here we assume a batch of size 1).
    stop_condition = False
    decoded_sentence = ""

    count = 0
    while not stop_condition:
        # Safety guard against infinite loops.
        count += 1
        if count > 1000:
            print('count exceeded')
            stop_condition = True

        output_words, dec_h, dec_c = decoder_model.predict(
            [target_seq, enc_output, enc_h, enc_c])

        # Greedy sampling: pick the most probable next word.
        word_index = np.argmax(output_words[0, -1, :])
        char = ""
        if word_index in target_index_word:
            char = target_index_word[word_index]
            decoded_sentence += ' ' + char
        else:
            stop_condition = True

        # Stop on the end-of-sequence token or when the output grows
        # too long (note this measures characters, not words).
        if char == 'eos' or len(decoded_sentence) >= max_input_len:
            stop_condition = True

        # Update the target sequence (of length 1) and the decoder states.
        target_seq = np.zeros((1, 1))
        target_seq[0, 0] = word_index
        enc_h, enc_c = dec_h, dec_c

    return decoded_sentence
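
Stopping at the first repeat truncates the output. If you want the stronger rule from the question, that every word appears at most once in the decoded sequence, you can enforce it directly instead: track the indices already emitted and mask them out of the softmax scores before the argmax, so a repeated word can never be selected. A minimal sketch, using a hypothetical helper pick_next_word and the same names as the loop above (the masking idea goes beyond the stop-on-repeat rule proposed here):

import numpy as np

def pick_next_word(output_words, used_indices):
    # Greedy choice over the decoder's softmax scores, with every
    # previously emitted word banned from selection.
    scores = output_words[0, -1, :].copy()
    if used_indices:
        # -inf can never win the argmax, so repeats are impossible.
        scores[list(used_indices)] = -np.inf
    word_index = int(np.argmax(scores))
    used_indices.add(word_index)
    return word_index

In get_predicted_sentence you would initialise used_indices = set() before the while-loop and replace the plain np.argmax line with word_index = pick_next_word(output_words, used_indices). If you prefer the simpler stop-on-repeat rule, keep the original argmax, record each word_index in a set, and set stop_condition = True when the new index is already in it.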