如何在 sklearn 中使用 BERT 和 Elmo 嵌入

How to use BERT and Elmo embedding with sklearn

我使用 sklearn 创建了一个使用 Tf-Idf 的文本分类器,我想使用 BERT 和 Elmo 嵌入而不是 Tf-Idf。

该怎么做?

我正在使用以下代码获取 BERT 嵌入:

from flair.data import Sentence
from flair.embeddings import TransformerWordEmbeddings

# Load a pre-trained BERT model as a word-level embedder.
bert_embedder = TransformerWordEmbeddings('bert-base-uncased')

# Wrap the raw text in a flair Sentence object.
example = Sentence('The grass is green .')

# Compute contextual embeddings for every token, in place.
bert_embedder.embed(example)
import pandas as pd
import numpy as np

from sklearn.compose import ColumnTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import MinMaxScaler
from sklearn.linear_model import LogisticRegression

# Feature pipeline: TF-IDF over the text column, min-max scaling over the
# numeric column.
column_trans = ColumnTransformer([
    ('tfidf', TfidfVectorizer(), 'text'),
    ('number_scaler', MinMaxScaler(), ['number'])
])

# Toy dataset: (text, number, target) rows.
records = [
    ['This process, however, afforded me no means of.', 20, 1],
    ['another long description', 21, 1],
    ['It never once occurred to me that the fumbling', 19, 0],
    ['How lovely is spring As we looked from Windsor', 18, 0]
]

frame = pd.DataFrame(records, columns=['text', 'number', 'target'])

# Fit the feature pipeline; densify the sparse TF-IDF output so it can be
# fed to the classifier as a plain array.
X = column_trans.fit_transform(frame)
X = X.toarray()
y = frame['target'].values

# Train a simple logistic-regression classifier on the combined features.
classifier = LogisticRegression(random_state=0)
classifier.fit(X, y)

Sklearn 提供了自定义 data transformer 的可能性(此处的 "transformer" 指 sklearn 的数据转换器,与机器学习中的 Transformer 模型无关)。

我实现了一个自定义的 sklearn 数据转换器,内部使用了你所用的 flair 库。请注意,我使用的是 TransformerDocumentEmbeddings 而不是 TransformerWordEmbeddings。此外,我还实现了一个直接使用 transformers 库的版本。

我还附上了一个 SO 问题的链接,其中讨论了应该使用 Transformer 的哪一层,值得一读。

我不熟悉 Elmo,不过我找到了一个使用 TensorFlow 实现它的示例(见 this 链接)。你可以参照我分享的代码进行修改,以让 Elmo 正常工作。

import torch
import numpy as np
from flair.data import Sentence
from flair.embeddings import TransformerDocumentEmbeddings
from sklearn.base import BaseEstimator, TransformerMixin


class FlairTransformerEmbedding(TransformerMixin, BaseEstimator):
    """Sklearn transformer that maps raw texts to document embeddings
    using a flair ``TransformerDocumentEmbeddings`` model (BERT by default).

    Parameters
    ----------
    model_name : str
        Hugging Face model identifier passed to flair.
    batch_size : int or None
        Forwarded to flair when given; flair's default is used otherwise.
    layers : str or None
        Forwarded to flair when given (which transformer layers to pool).
    """

    def __init__(self, model_name='bert-base-uncased', batch_size=None, layers=None):
        # From https://lvngd.com/blog/spacy-word-vectors-as-features-in-scikit-learn/
        # For pickling reason you should not load models in __init__
        self.model_name = model_name
        self.model_kw_args = {'batch_size': batch_size, 'layers': layers}
        # Drop unset options so flair's own defaults apply.
        self.model_kw_args = {k: v for k, v in self.model_kw_args.items()
                              if v is not None}

    def fit(self, X, y=None):
        # Stateless transformer: the pre-trained model learns nothing here.
        return self

    def transform(self, X):
        """Embed each text in ``X``; returns an array (n_samples, dim)."""
        model = TransformerDocumentEmbeddings(
                self.model_name, fine_tune=False,
                **self.model_kw_args)

        sentences = [Sentence(text) for text in X]
        # embed() works in place; some flair versions return None, so do
        # NOT rely on its return value — read back from `sentences`.
        # no_grad: inference only, avoids building the autograd graph.
        with torch.no_grad():
            model.embed(sentences)
        embedded = [s.get_embedding().reshape(1, -1) for s in sentences]
        return np.array(torch.cat(embedded).cpu())

import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from transformers import AutoTokenizer, AutoModel
from more_itertools import chunked

class TransformerEmbedding(TransformerMixin, BaseEstimator):
    """Sklearn transformer that maps raw texts to one vector per text using
    a raw Hugging Face ``transformers`` model (BERT by default).

    Parameters
    ----------
    model_name : str
        Hugging Face model identifier.
    batch_size : int
        Number of texts encoded per forward pass.
    layer : int
        Index used to select one vector from the model output (see note
        in ``transform``).
    """

    def __init__(self, model_name='bert-base-uncased', batch_size=1, layer=-1):
        # From https://lvngd.com/blog/spacy-word-vectors-as-features-in-scikit-learn/
        # For pickling reason you should not load models in __init__
        self.model_name = model_name
        self.layer = layer
        self.batch_size = batch_size

    def fit(self, X, y=None):
        # Stateless transformer: the pre-trained model learns nothing here.
        return self

    def transform(self, X):
        """Embed each text in ``X``; returns an array (n_samples, hidden)."""
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        model = AutoModel.from_pretrained(self.model_name)

        # Materialize once so plain index slicing works for any iterable
        # (replaces the third-party more_itertools.chunked dependency).
        texts = list(X)
        res = []
        for start in range(0, len(texts), self.batch_size):
            batch = texts[start:start + self.batch_size]
            encoded_input = tokenizer.batch_encode_plus(
                batch, return_tensors='pt', padding=True, truncation=True)
            # no_grad: inference only — skips the autograd graph and saves
            # memory; also makes .detach() before .numpy() unnecessary.
            with torch.no_grad():
                output = model(**encoded_input)
            # NOTE(review): last_hidden_state is (batch, seq_len, hidden), so
            # [:, self.layer] selects the TOKEN at position `layer` (default
            # -1 = last token, possibly padding), not a transformer layer —
            # confirm this is intended; [CLS] pooling would use index 0.
            embed = output.last_hidden_state[:, self.layer].numpy()
            res.append(embed)

        return np.concatenate(res)

在你的情况下,用下面的代码替换你的 ColumnTransformer:

# Same pipeline as before, but the text column now goes through the BERT
# document-embedding transformer instead of TF-IDF.
feature_steps = [
    ('embedding', FlairTransformerEmbedding(), 'text'),
    ('number_scaler', MinMaxScaler(), ['number'])
]
column_trans = ColumnTransformer(feature_steps)