Fastai 文本分类器:对看不见的数据进行批量预测
Fastai text classifier: batch prediction on unseen data
我一直在使用 fastai 的文本分类器 (https://docs.fast.ai/text.html)。我目前预测看不见的短语的情绪(正面或负面)如下:
def _unpack_prediction(self, text) -> Tuple[bool, float]:
out = self._model.predict(text)
return str(out[0]) == "positive", max(out[2][0].item(), out[2][1].item())
def example(self, messages: Sequence[str]):
results = map(self._unpack_prediction, messages)
for phrase, out in zip(messages, results):
print(f"{phrase[:100]}...[{'pos' if out[0] else 'neg'}] - [{out[1]:.2f}]")
给定一个短语列表:
("I love this movie",
"The actors are good, but this movie is definitely stupid",
"There is no plot at all!!! Just special effects ")
结果是:
I love this movie...[pos] - [1.00]
The actors are good, but this movie is definitely stupid...[neg] - [0.96]
There is no plot at all!!! Just special effects ...[neg] - [0.95]
但是,按顺序对短语应用预测非常慢。
有没有办法在不创建测试数据集的情况下使用 fastai 库进行批量预测?
你当然可以。这是执行此操作的示例代码
test_df = pd.read_csv(path_to_test_csv_file)
learn.data.add_test(test_df[target_col_name])
prob_preds = learn.get_preds(ds_type=DatasetType.Test, ordered=True)
我一直在使用 fastai 的文本分类器 (https://docs.fast.ai/text.html)。我目前预测看不见的短语的情绪(正面或负面)如下:
def _unpack_prediction(self, text) -> Tuple[bool, float]:
out = self._model.predict(text)
return str(out[0]) == "positive", max(out[2][0].item(), out[2][1].item())
def example(self, messages: Sequence[str]):
results = map(self._unpack_prediction, messages)
for phrase, out in zip(messages, results):
print(f"{phrase[:100]}...[{'pos' if out[0] else 'neg'}] - [{out[1]:.2f}]")
给定一个短语列表:
("I love this movie",
"The actors are good, but this movie is definitely stupid",
"There is no plot at all!!! Just special effects ")
结果是:
I love this movie...[pos] - [1.00]
The actors are good, but this movie is definitely stupid...[neg] - [0.96]
There is no plot at all!!! Just special effects ...[neg] - [0.95]
但是,按顺序对短语应用预测非常慢。
有没有办法在不创建测试数据集的情况下使用 fastai 库进行批量预测?
你当然可以。这是执行此操作的示例代码
test_df = pd.read_csv(path_to_test_csv_file)
learn.data.add_test(test_df[target_col_name])
prob_preds = learn.get_preds(ds_type=DatasetType.Test, ordered=True)