如何在 countvectorizer 中循环 ngrams 的参数?

How to loop the parameter for ngrams inside countvectorizer?

我想为 ngrams 尝试 6 种不同的组合,即:

  1. unigram (1,1)
  2. bigram (2,2)
  3. trigram (3,3)
  4. unigram + bigram (1,2)
  5. bigram + trigram (2,3)
  6. unigram + bigram + trigram (1,3)

是否可以使用for循环或任何其他方式循环遍历所有组合,而不是一个一个地更改参数?

pipeline = Pipeline([
('vect', CountVectorizer(tokenizer=no_tokenizer, lowercase=False, binary=True, ngram_range=(1,1))),
('clf', SGDClassifier(loss='log', penalty='l2', max_iter=20, verbose=0))
])
pipeline.fit(train.X, train.y)
preds = pipeline.predict(dev.X)
print(metrics.classification_report(dev.y, preds))

我也想从 print(metrics.classification_report(dev.y, preds)) 获得 6 种不同组合的所有输出。

我认为最干净的方法是将 GridSearchCV 与选定的 "param_grid" 一起使用,但这需要您选择特定的评分函数。此处描述了访问特定参数的语法 https://scikit-learn.org/stable/modules/compose.html“5.1.1.1.3. 嵌套参数”。

from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import GridSearchCV


pipeline = Pipeline([
    ('vect', CountVectorizer(tokenizer=no_tokenizer, lowercase=False, binary=True)),
    ('clf', SGDClassifier(loss='log', penalty='l2', max_iter=20, verbose=0))
])

param_grid = {'vect__n_gram_range': [(1, 1), (2, 2), (3, 3), (1, 2), (2, 3), (1, 3)]}
grid_search = GridSearchCV(pipeline, cv=1, param_grid=param_grid, scoring='f1')

grid_search.fit(train.X, train.y)
grid_search.score(dev.X, dev.y)

如果您真的很想为每个可能的 n_gram_range 获取完整的分类报告,您可以执行以下操作

from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import SGDClassifier


pipeline = Pipeline([
    ('vect', CountVectorizer(tokenizer=no_tokenizer, lowercase=False, binary=True)),
    ('clf', SGDClassifier(loss='log', penalty='l2', max_iter=20, verbose=0))
])

for n_gram_range in [(1, 1), (2, 2), (3, 3), (1, 2), (2, 3), (1, 3)]:
    pipeline.set_params(vect__n_gram_range=n_gram_range)
    pipeline.fit(train.X, train.y)
    preds = pipeline.predict(dev.X)
    print(metrics.classification_report(dev.y, preds))