使用自定义分类器通过 GridSearchCV 和 OneVsRestClassifier 进行多标签分类

Using custom classifier for mutilabel classification with GridSearchCV and OneVsRestClassifier

我正在尝试使用 OneVsRestClassifier 对一组评论进行多标签分类。我的 objective 是将每个评论标记为可能的主题列表。我的自定义分类器使用手动整理的单词列表及其在 csv 中的相应标签来标记每个评论。我正在尝试结合使用 VotingClassifier 的词袋技术和我的自定义分类器获得的结果。这是我现有代码的一部分:

import numpy as np

from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.ensemble import VotingClassifier
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.grid_search import GridSearchCV
from sklearn.linear_model import SGDClassifier
from sklearn.multiclass import OneVsRestClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MultiLabelBinarizer

class CustomClassifier(BaseEstimator, ClassifierMixin):
    def __init__(self, word_to_tag):
        self.word_to_tag = word_to_tag

    def fit(self, X, y=None):
        return self

    def predict_proba(self, X):
        prob = np.zeros(shape=(len(self.word_to_tag), 2))

        for index, comment in np.ndenumerate(X):
            prob[index] = [0.5, 0.5]
            for word, label in self.word_to_tag.iteritems():
                if (label == self.class_label) and (comment.find(word) >= 0):
                    prob[index] = [0, 1]
                    break

        return prob

    def _get_label(self, ...):
        # Need to have a way of knowing which label being classified
        # by OneVsRestClassifier (self.class_label)

bow_clf = Pipeline([('vect', CountVectorizer(stop_words='english', min_df=1, max_df=0.9)), 
                    ('tfidf', TfidfTransformer(use_idf=False)),
                    ('clf', SGDClassifier(loss='log', penalty='l2', alpha=1e-3, n_iter=5)),
                   ])
custom_clf = CustomClassifier(word_to_tag_dict)

ovr_clf = OneVsRestClassifier(VotingClassifier(estimators=[('bow', bow_clf), ('custom', custom_clf)],
                                               voting='soft'))

params = { 'estimator_weights': ([1, 1], [1, 2], [2, 1]) }
gs_clf = GridSearchCV(ovr_clf, params, n_jobs=-1, verbose=1, scoring='precision_samples')

binarizer = MultiLabelBinarizer()

gs_clf.fit(X, binarizer.fit_transform(y))

我的目的是使用这个通过多种启发式方法获得的手动整理的单词列表来改进仅应用词袋获得的结果。目前我正在努力寻找一种方法来在预测时知道哪个标签被分类,因为使用 OneVsRestClassifier 为每个标签创建了 CustomClassifier 的副本。

我认为您正在寻找 classes_ 属性:http://scikit-learn.org/dev/modules/generated/sklearn.multiclass.OneVsRestClassifier.html#sklearn.multiclass.OneVsRestClassifier