如何在 tidymodels 框架下提取我的分类器对单个数据点的预测?
How can I extract my classifier's predictions on individual data points under the tidymodels framework?
我在做一个文本分类项目,我一直在tidymodels框架下做所有事情。现在,我正在尝试调查特定数据点是否一直被全面错误标记。为此,我想进入单个样本的已保存预测。当我执行重采样并使用 collect_predictions 时,虽然我看到一个包含每个数据点的预测标签和实际标签的列表,但数据点本身的身份仍然隐藏。有一列可以追溯到 (.row),但我无法确认这一点。
我一直在生成我的重采样策略如下:
grades_split <- initial_split(tabled_texts2, strata = grade)
grades_train <- training(grades_split)
grades_test <- testing(grades_split)
folds <- vfold_cv(grades_train)
然后,在调整和拟合模型之后,我生成了重采样对象:
fitted_grades <- fit(final_wf, grades_train)
LR_rs <- fit_resamples(
fitted_grades,
folds,
control = control_resamples(save_pred = TRUE)
)
最后,我检查了这样的预测:
predictions <- collect_predictions(LR_rs)
View(predictions)
我得到一个 table,看起来像这样:
id
.pred_4
.pred_not 4
.行
.pred_class
等级
.config
折叠01
0.502905
0.497095
18
4
4
Preprocessor1_Model1
折叠01
0.484647
0.515353
22
不是 4
4
Preprocessor1_Model1
折叠01
0.481496
0.518504
23
不是 4
4
Preprocessor1_Model1
折叠01
0.492314
0.507686
40
不是 4
4
Preprocessor1_Model1
折叠01
0.477215
0.522785
52
不是 4
4
Preprocessor1_Model1
如何将这些值映射回原始数据?
这是一个类似的代表。在这个例子中,我希望能够具体看到哪些企鹅被错误分类,而不仅仅是任意的 .row 值(我很确定它不会映射回 1-1 到原始数据集)
library(tidyverse)
library(tidymodels)
library(tidytext)
library(modeldata)
library(naivebayes)
library(discrim)
set.seed(1)
data("penguins")
View(penguins)
nb_spec <- naive_Bayes() %>%
set_mode('classification') %>%
set_engine('naivebayes')
fitted_wf <- workflow() %>%
add_formula(species ~ island + flipper_length_mm) %>%
add_model(nb_spec) %>%
fit(penguins)
split <- initial_split(penguins)
train <- training(split)
test <- testing(split)
folds <- vfold_cv(train)
NB_rs <- fit_resamples(
fitted_wf,
folds,
control = control_resamples(save_pred = TRUE)
)
predictions <- collect_predictions(NB_rs)
View(predictions)
.row
列实际上会告诉您每个预测来自训练数据集的哪一行。让我们看看我们能否说服您:
library(tidyverse)
library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#> method from
#> required_pkgs.model_spec parsnip
library(discrim)
#>
#> Attaching package: 'discrim'
#> The following object is masked from 'package:dials':
#>
#> smoothness
set.seed(1)
data("penguins")
nb_spec <- naive_Bayes() %>%
set_mode('classification') %>%
set_engine('naivebayes')
fitted_wf <- workflow() %>%
add_formula(species ~ island + flipper_length_mm) %>%
add_model(nb_spec)
split <- penguins %>%
na.omit() %>%
initial_split()
penguin_train <- training(split)
penguin_test <- testing(split)
folds <- vfold_cv(penguin_train)
NB_rs <- fit_resamples(
fitted_wf,
folds,
control = control_resamples(save_pred = TRUE)
)
predictions <- collect_predictions(NB_rs)
让我们只看其中一个折叠:
predictions %>% filter(id == "Fold01")
#> # A tibble: 25 × 8
#> id .pred_Adelie .pred_Chinstrap .pred_Gentoo .row .pred_class species
#> <chr> <dbl> <dbl> <dbl> <int> <fct> <fct>
#> 1 Fold01 0.609 0.391 0.000000526 3 Adelie Adelie
#> 2 Fold01 0.182 0.818 0.000104 8 Chinstrap Adelie
#> 3 Fold01 0.423 0.577 0.000000325 9 Chinstrap Chinstrap
#> 4 Fold01 0.999 0.00120 0.00000137 21 Adelie Adelie
#> 5 Fold01 0.000178 0.0000310 1.00 27 Gentoo Gentoo
#> 6 Fold01 0.552 0.448 0.000000395 36 Adelie Adelie
#> 7 Fold01 0.997 0.000392 0.00275 45 Adelie Adelie
#> 8 Fold01 0.000211 0.000000780 1.00 48 Gentoo Gentoo
#> 9 Fold01 0.998 0.00129 0.00114 60 Adelie Adelie
#> 10 Fold01 0.00313 0.000100 0.997 79 Gentoo Gentoo
#> # … with 15 more rows, and 1 more variable: .config <chr>
这有第 3、8、9 行等。它是 folds
中第一个重采样的 评估 集。
现在让我们看一下训练数据:
penguin_train
#> # A tibble: 249 × 7
#> species island bill_length_mm bill_depth_mm flipper_length_mm body_mass_g
#> <fct> <fct> <dbl> <dbl> <int> <int>
#> 1 Chinstrap Dream 50.2 18.8 202 3800
#> 2 Gentoo Biscoe 50.2 14.3 218 5700
#> 3 Adelie Dream 38.1 17.6 187 3425
#> 4 Chinstrap Dream 51 18.8 203 4100
#> 5 Chinstrap Dream 52.7 19.8 197 3725
#> 6 Gentoo Biscoe 49.6 16 225 5700
#> 7 Chinstrap Dream 46.2 17.5 187 3650
#> 8 Adelie Dream 35.7 18 202 3550
#> 9 Chinstrap Dream 51.7 20.3 194 3775
#> 10 Gentoo Biscoe 50.4 15.7 222 5750
#> # … with 239 more rows, and 1 more variable: sex <fct>
由 reprex package (v2.0.0)
于 2021-07-30 创建
看第3、8、9行; species
匹配,因为它们是相同的行!
请注意,您可能会对 folds
中的每个折叠得到不同的预测,因为它们有不同的训练集,我们称之为 分析 集。
我在做一个文本分类项目,我一直在tidymodels框架下做所有事情。现在,我正在尝试调查特定数据点是否一直被全面错误标记。为此,我想进入单个样本的已保存预测。当我执行重采样并使用 collect_predictions 时,虽然我看到一个包含每个数据点的预测标签和实际标签的列表,但数据点本身的身份仍然隐藏。有一列可以追溯到 (.row),但我无法确认这一点。
我一直在生成我的重采样策略如下:
grades_split <- initial_split(tabled_texts2, strata = grade)
grades_train <- training(grades_split)
grades_test <- testing(grades_split)
folds <- vfold_cv(grades_train)
然后,在调整和拟合模型之后,我生成了重采样对象:
fitted_grades <- fit(final_wf, grades_train)
LR_rs <- fit_resamples(
fitted_grades,
folds,
control = control_resamples(save_pred = TRUE)
)
最后,我检查了这样的预测:
predictions <- collect_predictions(LR_rs)
View(predictions)
我得到一个 table,看起来像这样:
id | .pred_4 | .pred_not 4 | .行 | .pred_class | 等级 | .config |
---|---|---|---|---|---|---|
折叠01 | 0.502905 | 0.497095 | 18 | 4 | 4 | Preprocessor1_Model1 |
折叠01 | 0.484647 | 0.515353 | 22 | 不是 4 | 4 | Preprocessor1_Model1 |
折叠01 | 0.481496 | 0.518504 | 23 | 不是 4 | 4 | Preprocessor1_Model1 |
折叠01 | 0.492314 | 0.507686 | 40 | 不是 4 | 4 | Preprocessor1_Model1 |
折叠01 | 0.477215 | 0.522785 | 52 | 不是 4 | 4 | Preprocessor1_Model1 |
如何将这些值映射回原始数据?
这是一个类似的代表。在这个例子中,我希望能够具体看到哪些企鹅被错误分类,而不仅仅是任意的 .row 值(我很确定它不会映射回 1-1 到原始数据集)
library(tidyverse)
library(tidymodels)
library(tidytext)
library(modeldata)
library(naivebayes)
library(discrim)
set.seed(1)
data("penguins")
View(penguins)
nb_spec <- naive_Bayes() %>%
set_mode('classification') %>%
set_engine('naivebayes')
fitted_wf <- workflow() %>%
add_formula(species ~ island + flipper_length_mm) %>%
add_model(nb_spec) %>%
fit(penguins)
split <- initial_split(penguins)
train <- training(split)
test <- testing(split)
folds <- vfold_cv(train)
NB_rs <- fit_resamples(
fitted_wf,
folds,
control = control_resamples(save_pred = TRUE)
)
predictions <- collect_predictions(NB_rs)
View(predictions)
.row
列实际上会告诉您每个预测来自训练数据集的哪一行。让我们看看我们能否说服您:
library(tidyverse)
library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#> method from
#> required_pkgs.model_spec parsnip
library(discrim)
#>
#> Attaching package: 'discrim'
#> The following object is masked from 'package:dials':
#>
#> smoothness
set.seed(1)
data("penguins")
nb_spec <- naive_Bayes() %>%
set_mode('classification') %>%
set_engine('naivebayes')
fitted_wf <- workflow() %>%
add_formula(species ~ island + flipper_length_mm) %>%
add_model(nb_spec)
split <- penguins %>%
na.omit() %>%
initial_split()
penguin_train <- training(split)
penguin_test <- testing(split)
folds <- vfold_cv(penguin_train)
NB_rs <- fit_resamples(
fitted_wf,
folds,
control = control_resamples(save_pred = TRUE)
)
predictions <- collect_predictions(NB_rs)
让我们只看其中一个折叠:
predictions %>% filter(id == "Fold01")
#> # A tibble: 25 × 8
#> id .pred_Adelie .pred_Chinstrap .pred_Gentoo .row .pred_class species
#> <chr> <dbl> <dbl> <dbl> <int> <fct> <fct>
#> 1 Fold01 0.609 0.391 0.000000526 3 Adelie Adelie
#> 2 Fold01 0.182 0.818 0.000104 8 Chinstrap Adelie
#> 3 Fold01 0.423 0.577 0.000000325 9 Chinstrap Chinstrap
#> 4 Fold01 0.999 0.00120 0.00000137 21 Adelie Adelie
#> 5 Fold01 0.000178 0.0000310 1.00 27 Gentoo Gentoo
#> 6 Fold01 0.552 0.448 0.000000395 36 Adelie Adelie
#> 7 Fold01 0.997 0.000392 0.00275 45 Adelie Adelie
#> 8 Fold01 0.000211 0.000000780 1.00 48 Gentoo Gentoo
#> 9 Fold01 0.998 0.00129 0.00114 60 Adelie Adelie
#> 10 Fold01 0.00313 0.000100 0.997 79 Gentoo Gentoo
#> # … with 15 more rows, and 1 more variable: .config <chr>
这有第 3、8、9 行等。它是 folds
中第一个重采样的 评估 集。
现在让我们看一下训练数据:
penguin_train
#> # A tibble: 249 × 7
#> species island bill_length_mm bill_depth_mm flipper_length_mm body_mass_g
#> <fct> <fct> <dbl> <dbl> <int> <int>
#> 1 Chinstrap Dream 50.2 18.8 202 3800
#> 2 Gentoo Biscoe 50.2 14.3 218 5700
#> 3 Adelie Dream 38.1 17.6 187 3425
#> 4 Chinstrap Dream 51 18.8 203 4100
#> 5 Chinstrap Dream 52.7 19.8 197 3725
#> 6 Gentoo Biscoe 49.6 16 225 5700
#> 7 Chinstrap Dream 46.2 17.5 187 3650
#> 8 Adelie Dream 35.7 18 202 3550
#> 9 Chinstrap Dream 51.7 20.3 194 3775
#> 10 Gentoo Biscoe 50.4 15.7 222 5750
#> # … with 239 more rows, and 1 more variable: sex <fct>
由 reprex package (v2.0.0)
于 2021-07-30 创建看第3、8、9行; species
匹配,因为它们是相同的行!
请注意,您可能会对 folds
中的每个折叠得到不同的预测,因为它们有不同的训练集,我们称之为 分析 集。