Scikit-Learn GridSearchCV failing on a gensim LDA model
Here is the code that creates the model:
import gensim
NUM_TOPICS = 4
ldamodel = gensim.models.ldamodel.LdaModel(corpus, num_topics=NUM_TOPICS, id2word=dictionary, passes=100)
ldamodel.save('model5.gensim')
topics = ldamodel.print_topics(num_words=4)
print(topics)
Here is the GridSearchCV code:
search_params = {'n_components': [4, 6, 8, 10, 20], 'learning_decay': [.5, .7, .9]}
# Init Grid Search Class
model = GridSearchCV(ldamodel, param_grid=search_params)
# Do the Grid Search
model.fit(data_vectorized)
Here is the output:
*---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-108-1a35c49ac19e> in <module>
9
10 # Do the Grid Search
---> 11 model.fit(data_vectorized)
~\AppData\Local\Continuum\anaconda3\lib\site-packages\sklearn\model_selection\_search.py in fit(self, X, y, groups, **fit_params)
627
628 scorers, self.multimetric_ = _check_multimetric_scoring(
--> 629 self.estimator, scoring=self.scoring)
630
631 if self.multimetric_:
~\AppData\Local\Continuum\anaconda3\lib\site-packages\sklearn\metrics\_scorer.py in _check_multimetric_scoring(estimator, scoring)
471 if callable(scoring) or scoring is None or isinstance(scoring,
472 str):
--> 473 scorers = {"score": check_scoring(estimator, scoring=scoring)}
474 return scorers, False
475 else:
~\AppData\Local\Continuum\anaconda3\lib\site-packages\sklearn\metrics\_scorer.py in check_scoring(estimator, scoring, allow_none)
399 if not hasattr(estimator, 'fit'):
400 raise TypeError("estimator should be an estimator implementing "
--> 401 "'fit' method, %r was passed" % estimator)
402 if isinstance(scoring, str):
403 return get_scorer(scoring)
TypeError: estimator should be an estimator implementing 'fit' method, <gensim.models.ldamodel.LdaModel object at 0x000002121E55D3C8> was passed*
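You are trying to use the GridSearchCV object from the scikit-learn package, which requires the model object it runs on to implement certain methods (as the error message says: the fit method). scikit-learn and gensim are unrelated libraries, so a gensim LdaModel does not provide that interface. To make them compatible, subclass an Estimator class in scikit-learn and wrap the gensim training inside its fit method.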
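Also, as far as I can tell from the LdaModel documentation, it does not take the parameters you are trying to search over (n_components, learning_decay); its closest arguments appear to be num_topics and decay. You can only search over parameter values that the model actually uses.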
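For comparison, n_components and learning_decay are constructor arguments of scikit-learn's own LatentDirichletAllocation, so a grid with those keys works against that estimator without any wrapper. The sketch below assumes a toy list of raw strings in place of the question's data_vectorized, and sets learning_method="online" because learning_decay only affects online updates.

```python
from sklearn.decomposition import LatentDirichletAllocation
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.model_selection import GridSearchCV

# Toy documents (assumption) standing in for the question's corpus.
docs = [
    "cat dog pet", "dog pet vet", "cat pet vet",
    "stock bond fund", "bond fund bank", "stock fund bank",
] * 2
data_vectorized = CountVectorizer().fit_transform(docs)

# Same parameter names as in the question, with a smaller grid for the toy data.
search_params = {"n_components": [2, 3], "learning_decay": [0.5, 0.7, 0.9]}

lda = LatentDirichletAllocation(learning_method="online", random_state=0)
model = GridSearchCV(lda, param_grid=search_params, cv=2)
model.fit(data_vectorized)
print(model.best_params_)
```

GridSearchCV falls back to the estimator's own score method here, which for LatentDirichletAllocation is an approximate log-likelihood.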