如何在 tidymodels 框架下提取我的分类器对单个数据点的预测?

How can I extract my classifier's predictions on individual data points under the tidymodels framework?

我在做一个文本分类项目,我一直在tidymodels框架下做所有事情。现在,我正在尝试调查特定数据点是否一直被全面错误标记。为此,我想进入单个样本的已保存预测。当我执行重采样并使用 collect_predictions 时,虽然我看到一个包含每个数据点的预测标签和实际标签的列表,但数据点本身的身份仍然隐藏。有一列可以追溯到 (.row),但我无法确认这一点。

我一直在生成我的重采样策略如下:

grades_split <- initial_split(tabled_texts2, strata = grade)

grades_train <- training(grades_split)
grades_test <- testing(grades_split)

folds <- vfold_cv(grades_train)

然后,在调整和拟合模型之后,我生成了重采样对象:

fitted_grades <- fit(final_wf, grades_train)

LR_rs <- fit_resamples(
  fitted_grades,
  folds,
  control = control_resamples(save_pred = TRUE)
)

最后,我检查了这样的预测:

predictions <- collect_predictions(LR_rs)
View(predictions)

我得到一个 table,看起来像这样:

id .pred_4 .pred_not 4 .行 .pred_class 等级 .config
折叠01 0.502905 0.497095 18 4 4 Preprocessor1_Model1
折叠01 0.484647 0.515353 22 不是 4 4 Preprocessor1_Model1
折叠01 0.481496 0.518504 23 不是 4 4 Preprocessor1_Model1
折叠01 0.492314 0.507686 40 不是 4 4 Preprocessor1_Model1
折叠01 0.477215 0.522785 52 不是 4 4 Preprocessor1_Model1

如何将这些值映射回原始数据?

这是一个类似的代表。在这个例子中,我希望能够具体看到哪些企鹅被错误分类,而不仅仅是任意的 .row 值(我很确定它不会映射回 1-1 到原始数据集)

library(tidyverse)
library(tidymodels)
library(tidytext)
library(modeldata)
library(naivebayes)
library(discrim)
set.seed(1)

data("penguins")
View(penguins)
nb_spec <- naive_Bayes() %>%
  set_mode('classification') %>%
  set_engine('naivebayes')

fitted_wf <- workflow() %>%
  add_formula(species ~ island + flipper_length_mm) %>%
  add_model(nb_spec) %>%
  fit(penguins)


split <- initial_split(penguins)

train <- training(split)
test <- testing(split)

folds <- vfold_cv(train)

NB_rs <- fit_resamples(
  fitted_wf,
  folds,
  control = control_resamples(save_pred = TRUE)
)
predictions <- collect_predictions(NB_rs)
View(predictions)

.row 列实际上会告诉您每个预测来自训练数据集的哪一行。让我们看看我们能否说服您:

library(tidyverse)
library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip
library(discrim)
#> 
#> Attaching package: 'discrim'
#> The following object is masked from 'package:dials':
#> 
#>     smoothness
set.seed(1)

data("penguins")
nb_spec <- naive_Bayes() %>%
   set_mode('classification') %>%
   set_engine('naivebayes')

fitted_wf <- workflow() %>%
   add_formula(species ~ island + flipper_length_mm) %>%
   add_model(nb_spec) 

split <- penguins %>%
   na.omit() %>%
   initial_split()

penguin_train <- training(split)
penguin_test <- testing(split)

folds <- vfold_cv(penguin_train)

NB_rs <- fit_resamples(
   fitted_wf,
   folds,
   control = control_resamples(save_pred = TRUE)
)

predictions <- collect_predictions(NB_rs)

让我们只看其中一个折叠:

predictions %>% filter(id == "Fold01")
#> # A tibble: 25 × 8
#>    id     .pred_Adelie .pred_Chinstrap .pred_Gentoo  .row .pred_class species  
#>    <chr>         <dbl>           <dbl>        <dbl> <int> <fct>       <fct>    
#>  1 Fold01     0.609        0.391        0.000000526     3 Adelie      Adelie   
#>  2 Fold01     0.182        0.818        0.000104        8 Chinstrap   Adelie   
#>  3 Fold01     0.423        0.577        0.000000325     9 Chinstrap   Chinstrap
#>  4 Fold01     0.999        0.00120      0.00000137     21 Adelie      Adelie   
#>  5 Fold01     0.000178     0.0000310    1.00           27 Gentoo      Gentoo   
#>  6 Fold01     0.552        0.448        0.000000395    36 Adelie      Adelie   
#>  7 Fold01     0.997        0.000392     0.00275        45 Adelie      Adelie   
#>  8 Fold01     0.000211     0.000000780  1.00           48 Gentoo      Gentoo   
#>  9 Fold01     0.998        0.00129      0.00114        60 Adelie      Adelie   
#> 10 Fold01     0.00313      0.000100     0.997          79 Gentoo      Gentoo   
#> # … with 15 more rows, and 1 more variable: .config <chr>

这有第 3、8、9 行等。它是 folds 中第一个重采样的 评估 集。

现在让我们看一下训练数据:

penguin_train
#> # A tibble: 249 × 7
#>    species   island bill_length_mm bill_depth_mm flipper_length_mm body_mass_g
#>    <fct>     <fct>           <dbl>         <dbl>             <int>       <int>
#>  1 Chinstrap Dream            50.2          18.8               202        3800
#>  2 Gentoo    Biscoe           50.2          14.3               218        5700
#>  3 Adelie    Dream            38.1          17.6               187        3425
#>  4 Chinstrap Dream            51            18.8               203        4100
#>  5 Chinstrap Dream            52.7          19.8               197        3725
#>  6 Gentoo    Biscoe           49.6          16                 225        5700
#>  7 Chinstrap Dream            46.2          17.5               187        3650
#>  8 Adelie    Dream            35.7          18                 202        3550
#>  9 Chinstrap Dream            51.7          20.3               194        3775
#> 10 Gentoo    Biscoe           50.4          15.7               222        5750
#> # … with 239 more rows, and 1 more variable: sex <fct>

reprex package (v2.0.0)

于 2021-07-30 创建

看第3、8、9行; species 匹配,因为它们是相同的行!

请注意,您可能会对 folds 中的每个折叠得到不同的预测,因为它们有不同的训练集,我们称之为 分析 集。