训练 Doc2Vec 模型后准确率低

Low accuracy rate after training Doc2Vec model

我正在尝试训练 Doc2Vec 模型以创建多标签文本分类器。
为此,我选择了一个包含大约 70000 篇文章的数据集,每篇文章包含 1500 到 2000 个单词。
这些文章分为 5 类.
在设置我的输入时,我选择了相应标签作为文档的标签。 我做了如下: tagged_article = data.apply(lambda r: TaggedDocument(words=r['article'].split(), tags=[r.labels]), axis=1)
然后我用以下行代码训练了我的模型:

model_dbow = Doc2Vec(dm=1, vector_size=300, negative=5, min_count=10, workers=cores)
model_dbow.build_vocab([x for x in tqdm(tagged_article.values)])

print("Training the Doc2Vec model for ", no_epochs, "number of epochs" )
for epoch in range(no_epochs):
     model_dbow.train(utils.shuffle([x for x in tqdm(tagged_article.values)]),total_examples=len(tagged_article.values), epochs=1)
     model_dbow.alpha -= 0.002
     model_dbow.min_alpha = model_dbow.alpha   

之后我创建了一个逻辑回归模型来预测每篇文章的标签。

为此,我创建了以下函数:\

def vec_for_learning(model, tagged_docs):
sents = tagged_docs.values
targets, regressors = zip(*[(doc.tags[0], model.infer_vector(doc.words, steps=inference_steps)) for doc in tqdm(sents)])
return targets, regressors

y_train, X_train = vec_for_learning(model_dbow, tagged_article)

logreg = LogisticRegression(solver='lbfgs',max_iter=1000)
logreg.fit(X_train, y_train)

不幸的是,我得到了一个非常糟糕的结果。事实上,我得到 22% 的准确率和 21% 的 F1 分数

你能解释一下为什么我会得到这些糟糕的结果吗?

首先,您几乎肯定不想在自己管理 alpha 的同时使用自己的循环多次调用 train()。参见:

因为你没有显示你的 no_epochs 值,我不能确定你做了绝对最糟糕的事情 - 最终将 alpha 减少到负值 - 但你可能.尽管如此,仍然不需要那个容易出错的循环。 (而且,您可能想联系向您推荐此代码模板的任何来源,并让他们知道他们正在推广一种反模式。)

使用您仅有的 5 个已知标签作为文档标签也可能是错误的。这意味着该模型基本上只学习 5 个文档向量,就好像所有文章都只是 5 个巨型文本的片段。虽然有时使用(或添加)已知标签作为标签很有帮助,但更 classic 的训练方式 Doc2Vec 为每个文档提供了一个唯一的 ID,因此模型正在学习(在您的情况下)关于70,000 个不同的文档向量,并且可以更丰富地模拟所有文档和标签以各种不规则形状跨越的文档可能性空间。

虽然您的数据肯定与显示 Doc2Vec 算法价值的已发表作品相当,但您的语料库并不庞大(而且不清楚您的词汇量可能有多大和多样化)。因此,对于您拥有的 quanitity/variety 数据,300 个维度可能过大,或者 min_count=10 在修剪不太重要和不太好的采样词时过于激进(或不够激进)。

最后,请注意 Doc2Vec class 将继承默认的 epochs 值 5,但大多数已发表的作品使用 10-20 个训练周期,并且通常使用更小的数据集甚至更多可能会有所帮助。此外,推理将在模型创建时重用相同的 epochs 集(或默认设置),并且在(至少)与训练相同数量的 epochs 时效果最好 - 虽然不清楚 inference_steps你正在使用。

(作为代码易读性的一个单独问题:您已将模型命名为 model_dbow,但是通过使用 dm=1 您实际上使用的是 PV-DM 模式, 不是 PV-DBOW 模式。)