SpaCy-transformers 回归输出
SpaCy-transformers regression output
我想要回归输出而不是分类。例如:我想要一个从 0 到 1 的浮点输出值而不是 n 类。
这是包 github 页面中的简约示例:
import spacy
from spacy.util import minibatch
import random
import torch
is_using_gpu = spacy.prefer_gpu()
if is_using_gpu:
torch.set_default_tensor_type("torch.cuda.FloatTensor")
nlp = spacy.load("en_trf_bertbaseuncased_lg")
print(nlp.pipe_names) # ["sentencizer", "trf_wordpiecer", "trf_tok2vec"]
textcat = nlp.create_pipe("trf_textcat", config={"exclusive_classes": True})
for label in ("POSITIVE", "NEGATIVE"):
textcat.add_label(label)
nlp.add_pipe(textcat)
optimizer = nlp.resume_training()
for i in range(10):
random.shuffle(TRAIN_DATA)
losses = {}
for batch in minibatch(TRAIN_DATA, size=8):
texts, cats = zip(*batch)
nlp.update(texts, cats, sgd=optimizer, losses=losses)
print(i, losses)
nlp.to_disk("/bert-textcat")
有没有一种简单的方法可以让 trf_textcat
作为回归器工作?还是意味着扩展图书馆?
我想出了一个解决方法:从 nlp 管道中提取矢量表示:
vector_repres = nlp('Test text').vector
对所有文本条目执行此操作后,您最终会得到文本的固定维度表示。假设您有连续的输出值,请随意使用任何估计器,包括具有线性输出的神经网络。
请注意,向量表示是文本中所有单词的向量嵌入的平均值 - 对于您的情况,它可能不是最佳解决方案。
我想要回归输出而不是分类。例如:我想要一个从 0 到 1 的浮点输出值而不是 n 类。
这是包 github 页面中的简约示例:
import spacy
from spacy.util import minibatch
import random
import torch
is_using_gpu = spacy.prefer_gpu()
if is_using_gpu:
torch.set_default_tensor_type("torch.cuda.FloatTensor")
nlp = spacy.load("en_trf_bertbaseuncased_lg")
print(nlp.pipe_names) # ["sentencizer", "trf_wordpiecer", "trf_tok2vec"]
textcat = nlp.create_pipe("trf_textcat", config={"exclusive_classes": True})
for label in ("POSITIVE", "NEGATIVE"):
textcat.add_label(label)
nlp.add_pipe(textcat)
optimizer = nlp.resume_training()
for i in range(10):
random.shuffle(TRAIN_DATA)
losses = {}
for batch in minibatch(TRAIN_DATA, size=8):
texts, cats = zip(*batch)
nlp.update(texts, cats, sgd=optimizer, losses=losses)
print(i, losses)
nlp.to_disk("/bert-textcat")
有没有一种简单的方法可以让 trf_textcat
作为回归器工作?还是意味着扩展图书馆?
我想出了一个解决方法:从 nlp 管道中提取矢量表示:
vector_repres = nlp('Test text').vector
对所有文本条目执行此操作后,您最终会得到文本的固定维度表示。假设您有连续的输出值,请随意使用任何估计器,包括具有线性输出的神经网络。
请注意,向量表示是文本中所有单词的向量嵌入的平均值 - 对于您的情况,它可能不是最佳解决方案。