R: Predicting with new factor levels in mlr with regr.svm task
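I am using the mlr package to predict from an SVM. If my validation set contains factor levels that are not present in my training data, prediction fails regardless of how I set fix.factors.prediction when creating the SVM learner.
What is the correct way to handle this? Using e1071::svm() directly returns a response for the new factor levels, but how can I do the same thing through mlr?
Example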
library(mlr)
library(dplyr)
set.seed(575)
data(iris)
# Split data
train_set <- sample_frac(iris, 4/5)
valid_set <- setdiff(iris, train_set)
# Remove all "setosa" values from the training set
train_set[train_set$Species == "setosa", "Species"] <-
  sample(c("virginica", "versicolor"),
         sum(train_set$Species == "setosa"), replace = TRUE)
# Fit model
iris_task <- makeRegrTask(data = train_set, target = "Petal.Width")
svm_lrn <- makeLearner("regr.svm", fix.factors.prediction = TRUE)
svm_mod <- train(svm_lrn, iris_task)
# Predict on new factor levels
predict(svm_mod, newdata = valid_set)
Error in (function (..., row.names = NULL, check.rows = FALSE, check.names = TRUE, : arguments imply differing number of rows: 29, 20
With makeLearner("regr.svm", fix.factors.prediction = FALSE), calling predict fails with the following error instead:
Error in scale.default(newdata[, object$scaled, drop = FALSE], center = object$x.scale$"scaled:center", : length of 'center' must equal the number of columns of 'x'
Things that work
I can generate predictions when I subset the validation set to the factor levels present in the training set:
predict(svm_mod, newdata = valid_set %>%
filter(Species %in% train_set$Species))
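The same check can be sketched in base R, which may help to find the offending levels before calling predict. The helper unseen_levels() is hypothetical (it is not part of mlr or e1071), and the small train/valid data frames are built from iris purely for illustration.

```r
# Hypothetical helper (not part of mlr): for every factor column shared by
# `train` and `new`, list the values that occur in `new` but never in `train`.
unseen_levels <- function(train, new) {
  factor_cols <- names(train)[vapply(train, is.factor, logical(1))]
  factor_cols <- intersect(factor_cols, names(new))
  out <- lapply(factor_cols, function(col) {
    setdiff(unique(as.character(new[[col]])),
            unique(as.character(train[[col]])))
  })
  names(out) <- factor_cols
  out[lengths(out) > 0]
}

# Illustrative data: training rows without "setosa", one validation row per species
train <- droplevels(iris[iris$Species != "setosa", ])
valid <- iris[c(1, 51, 101), ]

unseen <- unseen_levels(train, valid)
print(unseen)

# Keep only the rows whose Species value was seen during training
ok <- !(as.character(valid$Species) %in% unseen$Species)
valid_seen <- valid[ok, ]
print(nrow(valid_seen))
```

This only avoids the error by not predicting for the unseen rows; it does not produce predictions for them.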
There is no error when I use a different learner:
nnet_lrn <- makeLearner("regr.nnet", fix.factors.prediction = TRUE)
nnet_mod <- train(nnet_lrn, iris_task)
predict(nnet_mod, newdata = valid_set)
Or when I use the same learner directly from its package:
e1071_mod <-
e1071::svm(Petal.Width ~ Sepal.Length + Sepal.Width +
Petal.Length + Species, train_set)
predict(e1071_mod, newdata = valid_set)
Session info
R version 3.4.4 (2018-03-15)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 14.04.6 LTS
Matrix products: default
BLAS: /usr/lib/libblas/libblas.so.3.0
LAPACK: /usr/lib/lapack/liblapack.so.3.0
locale:
[1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C
[3] LC_TIME=en_US.UTF-8 LC_COLLATE=en_US.UTF-8
[5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8
[7] LC_PAPER=en_US.UTF-8 LC_NAME=C
[9] LC_ADDRESS=C LC_TELEPHONE=C
[11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] dplyr_0.8.0.1 mlr_2.14.0.9000 ParamHelpers_1.12
loaded via a namespace (and not attached):
[1] parallelMap_1.4 Rcpp_1.0.1 pillar_1.4.1
[4] compiler_3.4.4 class_7.3-14 tools_3.4.4
[7] tibble_2.1.3 gtable_0.3.0 checkmate_1.9.3
[10] lattice_0.20-38 pkgconfig_2.0.2 rlang_0.3.99.9003
[13] Matrix_1.2-14 fastmatch_1.1-0 rstudioapi_0.8
[16] yaml_2.2.0 parallel_3.4.4 e1071_1.7-1
[19] nnet_7.3-12 grid_3.4.4 tidyselect_0.2.5
[22] glue_1.3.1 data.table_1.12.2 R6_2.4.0
[25] XML_3.98-1.20 survival_2.41-3 ggplot2_3.2.0.9000
[28] purrr_0.3.2 magrittr_1.5 backports_1.1.4
[31] scales_1.0.0.9000 BBmisc_1.11 splines_3.4.4
[34] assertthat_0.2.1 colorspace_1.3-2 stringi_1.4.3
[37] lazyeval_0.2.2 munsell_0.5.0 crayon_1.3.4
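OK, this was a bit challenging. A few things first:
- e1071::svm() cannot handle missing factor levels in newdata (Error in predict.svm: test data does not match model)
- The manual execution of your example only runs because you did not drop the unused factor levels in train_set
- The argument fix.factors.prediction does not do what it is supposed to do. I have published a temporary fix in the fix-factors branch of mlr-org/mlr. The fix is very dirty and only a proof of concept. I might clean it up.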
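As a hedged illustration of the last point: my understanding (an assumption, not verified against the mlr source) is that fix.factors.prediction = TRUE re-encodes factor features in newdata with the levels seen in training, so unseen levels turn into NA. If the underlying predict method then drops incomplete rows, there are fewer predictions than rows, which would be consistent with the "differing number of rows: 29, 20" error. The base-R sketch below only shows that mechanism on made-up subsets of iris.

```r
# Illustrative data: training rows without "setosa", validation rows with it
train <- droplevels(iris[iris$Species != "setosa", ])
valid <- iris[c(1, 2, 51, 101), ]

# Re-encode the new data using only the levels seen in training;
# values outside those levels become NA
valid$Species <- factor(as.character(valid$Species),
                        levels = levels(train$Species))

n_na <- sum(is.na(valid$Species))
print(levels(valid$Species))
print(n_na)

# Any step that drops incomplete rows now returns fewer rows than were passed in
n_all <- nrow(valid)
n_complete <- nrow(na.omit(valid))
print(c(all = n_all, complete = n_complete))
```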
Proof that the manual execution does not work:
library(mlr)
#> Loading required package: ParamHelpers
#> Registered S3 methods overwritten by 'ggplot2':
#> method from
#> [.quosures rlang
#> c.quosures rlang
#> print.quosures rlang
library(dplyr)
#>
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#>
#> filter, lag
#> The following objects are masked from 'package:base':
#>
#> intersect, setdiff, setequal, union
set.seed(575)
data(iris)
# Split data
train_set <- sample_frac(iris, 4 / 5)
valid_set <- setdiff(iris, train_set)
# Remove all "setosa" values from the training set
train_set[train_set$Species == "setosa", "Species"] <-
sample(c("virginica", "versicolor"),
sum(train_set$Species == "setosa"), replace = TRUE)
# this is important
train_set = droplevels(train_set)
e1071_mod <- e1071::svm(Petal.Width ~ Sepal.Length + Sepal.Width +
Petal.Length + Species, train_set)
predict(e1071_mod, newdata = valid_set)
#> Error in scale.default(newdata[, object$scaled, drop = FALSE], center = object$x.scale$"scaled:center", : length of 'center' must equal the number of columns of 'x'
Created on 2019-06-13 by the reprex package (v0.3.0)
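To see why droplevels() matters here, a small base-R sketch: reassigning values does not remove a level from a factor, so without droplevels() the training data still carries "setosa" as an unused level and the design matrix has the same columns as for the validation data. The sketch uses model.matrix() with default contrasts only to show the column counts; e1071 builds its design matrix internally and the exact columns there may differ. The train/valid subsets are made up for illustration.

```r
# Training rows without any "setosa" observation
train <- iris[iris$Species != "setosa", ]
print(levels(train$Species))          # "setosa" is still a level, just unused

train_dropped <- droplevels(train)
print(levels(train_dropped$Species))  # the unused level is gone

# One validation row per species, so all three levels are present
valid <- iris[c(1, 51, 101), ]

f <- ~ Sepal.Length + Sepal.Width + Petal.Length + Species

# Number of design-matrix columns for each data set
n_train   <- ncol(model.matrix(f, train))
n_dropped <- ncol(model.matrix(f, train_dropped))
n_valid   <- ncol(model.matrix(f, valid))
print(c(train = n_train, dropped = n_dropped, valid = n_valid))
```

With the unused level kept, training and validation matrices line up. After droplevels() the training matrix has one column fewer, which matches the kind of mismatch reported in the scale.default error above.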
Working example using the fix provided in mlr:
remotes::install_github("mlr-org/mlr@fix-factors")
#> Downloading GitHub repo mlr-org/mlr@fix-factors
library(mlr)
#> Loading required package: ParamHelpers
#> Registered S3 methods overwritten by 'ggplot2':
#> method from
#> [.quosures rlang
#> c.quosures rlang
#> print.quosures rlang
library(dplyr)
#>
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#>
#> filter, lag
#> The following objects are masked from 'package:base':
#>
#> intersect, setdiff, setequal, union
set.seed(575)
data(iris)
# Split data
train_set <- sample_frac(iris, 4 / 5)
valid_set <- setdiff(iris, train_set)
# Remove all "setosa" values from the training set
train_set[train_set$Species == "setosa", "Species"] <-
sample(c("virginica", "versicolor"),
sum(train_set$Species == "setosa"), replace = TRUE)
# this is important
train_set = droplevels(train_set)
# Fit model
iris_task <- makeRegrTask(data = train_set, target = "Petal.Width")
svm_lrn <- makeLearner("regr.svm", fix.factors.prediction = TRUE)
svm_mod <- train(svm_lrn, iris_task)
# Predict on new factor levels
predict(svm_mod, newdata = valid_set)
#> Prediction: 30 observations
#> predict.type: response
#> threshold:
#> time: 0.00
#> truth response
#> 1 0.3 0.2457751
#> 2 0.1 0.2730398
#> 3 0.2 0.2717464
#> 4 0.1 0.2717748
#> 5 0.1 0.2651599
#> 6 0.4 0.2582568
#> ... (#rows: 30, #cols: 2)
Created on 2019-06-13 by the reprex package (v0.3.0)