获得 LSTM 的前 3 个预测,而不仅仅是前 3 个
Get top 3 prediction of LSTM instead of only the top
我有一个针对文本内容训练的 LSTM 模型。现在我想用那个模型来生成一些句子。但是我不是总是选择最好的选项,而是希望它从例如前 3 个中选择 select,这样它就可以用相同的输入生成不同的句子,因为现在我几乎对每个输入都得到相同的答案。我该如何修改这段代码,这样才有可能,我知道我需要删除 np.argmax
但我不知道如何删除 return 前 3 个最高值的索引。
当前代码:
def prediction(seed_text, next_words):
for _ in range(next_words):
token_list = tokenizer.texts_to_sequences([seed_text])[0]
token_list = pad_sequences([token_list], maxlen=max_seq_length-1, padding='pre')
predicted = np.argmax(model.predict(token_list, verbose=0), axis=-1)
ouput_word = ""
for word, index in tokenizer.word_index.items():
if index == predicted:
output_word = word
break
seed_text += ' '+output_word
return seed_text
np.argsort
将按照从小到大的顺序为您提供数组中项目的索引:https://numpy.org/doc/stable/reference/generated/numpy.argsort.html
这是一个使用 argsort
的示例。请注意,预测值最低的那个(索引 2,预测值为 0.05 的“c”)被排除在打印的内容之外。
import numpy as np
word_index = {'a': 0, 'b': 1, 'c': 2, 'd': 3}
predictions = np.array([0.1, 0.7, 0.05, 0.15])
# add negative to sort large to small; slice to select just up to 3rd index
top_3 = np.argsort(-predictions)[:3]
for word, index in word_index.items():
if index in top_3:
print(word)
#> a
#> b
#> d
我有一个针对文本内容训练的 LSTM 模型。现在我想用那个模型来生成一些句子。但是我不是总是选择最好的选项,而是希望它从例如前 3 个中选择 select,这样它就可以用相同的输入生成不同的句子,因为现在我几乎对每个输入都得到相同的答案。我该如何修改这段代码,这样才有可能,我知道我需要删除 np.argmax
但我不知道如何删除 return 前 3 个最高值的索引。
当前代码:
def prediction(seed_text, next_words):
for _ in range(next_words):
token_list = tokenizer.texts_to_sequences([seed_text])[0]
token_list = pad_sequences([token_list], maxlen=max_seq_length-1, padding='pre')
predicted = np.argmax(model.predict(token_list, verbose=0), axis=-1)
ouput_word = ""
for word, index in tokenizer.word_index.items():
if index == predicted:
output_word = word
break
seed_text += ' '+output_word
return seed_text
np.argsort
将按照从小到大的顺序为您提供数组中项目的索引:https://numpy.org/doc/stable/reference/generated/numpy.argsort.html
这是一个使用 argsort
的示例。请注意,预测值最低的那个(索引 2,预测值为 0.05 的“c”)被排除在打印的内容之外。
import numpy as np
word_index = {'a': 0, 'b': 1, 'c': 2, 'd': 3}
predictions = np.array([0.1, 0.7, 0.05, 0.15])
# add negative to sort large to small; slice to select just up to 3rd index
top_3 = np.argsort(-predictions)[:3]
for word, index in word_index.items():
if index in top_3:
print(word)
#> a
#> b
#> d