Training, Tuning, Cross-Validating, and Testing Ranger (Random Forest) Quantile Regression Model?

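Could someone share how to train, tune (hyperparameters), cross-validate, and test a ranger quantile regression model, along with error evaluation? Using the iris or Boston housing dataset?

The reason I ask is that I haven't been able to find many examples or walkthroughs using quantile regression on Kaggle, random blogs, or YouTube. Most of the problems I've come across are classification problems.

I'm currently using a quantile regression model, but I'd like to see other examples, especially of hyperparameter tuning.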

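This function has a lot of parameters. Since this isn't a forum for explaining what they all mean, I really suggest taking the how and why questions to Cross Validated. (Or look for questions there that may already have been answered.)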

library(tidyverse)
library(ranger)
library(caret)
library(funModeling)

data(iris)

#----------- setup data -----------
# this doesn't include exploration or cleaning which are both necessary
summary(iris)
df_status(iris)

#----------------- create training sample ----------------
set.seed(395280469) # for replicability

# create training sample partition (80/20 split)
tr <- createDataPartition(iris$Species, 
                          p = .8, 
                          list = FALSE)

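There are a lot of ways to split the data, but I prefer caret, because it balances the factor levels across the partitions if a factor is what you feed it.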
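
To see the balancing in action, here is a small sketch that repeats the split and counts the species in each partition. It assumes the same seed and `p = .8` as above; the names `idx`, `train_counts`, and `test_counts` are only for illustration.

```r
library(caret)

data(iris)
set.seed(395280469)

# stratified 80/20 split on the Species factor
tr  <- createDataPartition(iris$Species, p = .8, list = FALSE)
idx <- tr[, 1]   # createDataPartition returns a one-column matrix

# class counts in the training and test partitions
train_counts <- table(iris$Species[idx])
test_counts  <- table(iris$Species[-idx])

train_counts
test_counts
```

Because iris has 50 rows per species, each species should contribute the same number of rows to the training partition, which is what a plain `sample()` would not guarantee.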

#--------- First model ---------
fit.r <- ranger(Sepal.Length ~ ., 
                data = iris[tr, ],
                write.forest = TRUE,
                importance = 'permutation',
                quantreg = TRUE,
                keep.inbag = TRUE,
                replace = FALSE)
fit.r
# Ranger result
# 
# Call:
#  ranger(Sepal.Length ~ ., data = iris[tr, ], write.forest = TRUE,
#     importance = "permutation", quantreg = TRUE, keep.inbag = TRUE, 
#     replace = FALSE) 
# 
# Type:                             Regression 
# Number of trees:                  500 
# Sample size:                      120 
# Number of independent variables:  4 
# Mtry:                             2 
# Target node size:                 5 
# Variable importance mode:         permutation 
# Splitrule:                        variance 
# OOB prediction error (MSE):       0.1199364 
# R squared (OOB):                  0.8336928  

p.r <- predict(fit.r, iris[-tr, -1],
               type = 'quantiles')

The quantiles default to .1, .5, and .9:

postResample(p.r$predictions[, 1], iris[-tr, 1])
#      RMSE  Rsquared       MAE 
# 0.5165946 0.7659124 0.4036667  

postResample(p.r$predictions[, 2], iris[-tr, 1])
#      RMSE  Rsquared       MAE 
# 0.3750556 0.7587326 0.3133333  

postResample(p.r$predictions[, 3], iris[-tr, 1])
#      RMSE  Rsquared       MAE 
# 0.6488991 0.7461830 0.5703333  
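
RMSE and MAE measure distance from the actual value, so they will naturally favor the median over the outer quantiles. A common alternative for scoring a quantile prediction is the pinball (quantile) loss. The following is a minimal sketch, not part of the models above: `pinball_loss` is a helper I wrote for illustration, and the split is a plain random `sample()` rather than the caret partition, just to keep the block self-contained.

```r
library(ranger)

# pinball (quantile) loss: under-predictions are weighted by tau,
# over-predictions by (1 - tau)
pinball_loss <- function(actual, predicted, tau) {
  err <- actual - predicted
  mean(pmax(tau * err, (tau - 1) * err))
}

data(iris)
set.seed(326)
idx   <- sample(nrow(iris), 120)   # simple random split, for illustration only
train <- iris[idx, ]
test  <- iris[-idx, ]

fit <- ranger(Sepal.Length ~ ., data = train, quantreg = TRUE)

# ask for the quantiles explicitly instead of relying on the defaults
taus <- c(.1, .5, .9)
pred <- predict(fit, test, type = "quantiles", quantiles = taus)$predictions

# one loss value per quantile column (lower is better)
losses <- sapply(seq_along(taus), function(j) {
  pinball_loss(test$Sepal.Length, pred[, j], taus[j])
})
setNames(losses, paste0("q", taus))
```

Each column is scored against its own target quantile, so the .1 and .9 predictions are no longer penalized simply for sitting below or above the observations.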

Here is what this looks like in practice:

# plot all three quantile predictions against the actual values
ggplot(data.frame(p.Q1 = p.r$predictions[, 1],
                  p.Q5 = p.r$predictions[, 2],
                  p.Q9 = p.r$predictions[, 3],
                  Actual = iris[-tr, 1])) +
  geom_point(aes(x = Actual, y = p.Q1, color = "P.Q1")) +
  geom_point(aes(x = Actual, y = p.Q5, color = "P.Q5")) +
  geom_point(aes(x = Actual, y = p.Q9, color = "P.Q9")) +
  geom_line(aes(Actual, Actual, color = "Actual")) +
  scale_color_viridis_d(end = .8, "Error",
                        direction = -1)+
  theme_bw()

# a closer look at the .9 quantile: each segment shows the gap
# between the prediction and the actual value
ggplot(data.frame(p.Q9 = p.r$predictions[, 3],
                  Actual = iris[-tr, 1])) +
  geom_point(aes(x = Actual, y = p.Q9, color = "P.Q9")) +
  geom_segment(aes(x = Actual, xend = Actual, 
                   y = Actual, yend = p.Q9)) +
  geom_line(aes(Actual, Actual, color = "Actual")) +
  scale_color_viridis_d(end = .8, "Error",
                        direction = -1)+
  theme_bw()
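
A numeric companion to the plots is the empirical coverage of the band between the .1 and .9 quantiles, that is, the share of test observations that fall inside it. This is a rough sketch under my own assumptions: a plain random split instead of the caret partition, and illustrative names (`inside`, `coverage`, `width`).

```r
library(ranger)

data(iris)
set.seed(326)
idx <- sample(nrow(iris), 120)   # simple random split, for illustration only

fit  <- ranger(Sepal.Length ~ ., data = iris[idx, ], quantreg = TRUE)
pred <- predict(fit, iris[-idx, ], type = "quantiles",
                quantiles = c(.1, .9))$predictions

actual <- iris$Sepal.Length[-idx]

# TRUE where the actual value lies inside the predicted band
inside <- actual >= pred[, 1] & actual <= pred[, 2]

coverage <- mean(inside)                  # nominal target for this band is 0.8
width    <- mean(pred[, 2] - pred[, 1])   # average width of the band

c(coverage = coverage, width = width)
```

With only 30 test rows the coverage estimate is noisy, so treat it as a sanity check rather than a precise measurement.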

#------------ ranger model with options --------------
# last call used default 
#    splitrule: variance, use "extratrees" (only 2 for this one)
#    mtry = 2, use 3 this time
#    min.node.size = 5, using 6 this time
#    using num.threads = 15 ** this is the number of cores on YOUR device
#        change accordingly --- if you don't know, drop this one

set.seed(326)
fit.r2 <- ranger(Sepal.Length ~ ., 
                data = iris[tr, ],
                write.forest = TRUE,
                importance = 'permutation',
                quantreg = TRUE,
                keep.inbag = TRUE,
                replace = FALSE,
                splitrule = "extratrees",
                mtry = 3,
                min.node.size = 6,
                num.threads = 15)
fit.r2
# Ranger result
# Type:                             Regression 
# Number of trees:                  500 
# Sample size:                      120 
# Number of independent variables:  4 
# Mtry:                             3 
# Target node size:                 6 
# Variable importance mode:         permutation 
# Splitrule:                        extratrees 
# Number of random splits:          1 
# OOB prediction error (MSE):       0.1107299 
# R squared (OOB):                  0.8464588  

This model produced similar results.

p.r2 <- predict(fit.r2, iris[-tr, -1],
               type = 'quantiles')

postResample(p.r2$predictions[, 1], iris[-tr, 1])
#      RMSE  Rsquared       MAE 
# 0.4932883 0.8144309 0.4000000  
 
postResample(p.r2$predictions[, 2], iris[-tr, 1])
#      RMSE  Rsquared       MAE 
# 0.3610171 0.7643744 0.3100000  

postResample(p.r2$predictions[, 3], iris[-tr, 1])
#      RMSE  Rsquared       MAE 
# 0.6555939 0.8141144 0.5603333 

The predictions were very similar overall, too. This isn't a very large set of data, and it has few predictors. How much do they contribute?

importance(fit.r2)
#  Sepal.Width Petal.Length  Petal.Width      Species 
#   0.06138883   0.71052453   0.22956522   0.18082998  
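
The options in the model above were chosen by hand. One rough way to search more systematically is to loop over a small grid and compare the out-of-bag error that ranger already reports. This is only a sketch: the grid values are hypothetical, the split is a plain random `sample()`, and the OOB MSE scores the mean prediction rather than the quantiles, so it is a proxy at best.

```r
library(ranger)

data(iris)
set.seed(326)
train <- iris[sample(nrow(iris), 120), ]   # simple random split, for illustration

# hypothetical candidate values; widen or narrow as needed
grid <- expand.grid(mtry          = 1:4,
                    min.node.size = c(3, 5, 7),
                    splitrule     = c("variance", "extratrees"),
                    stringsAsFactors = FALSE)

grid$oob_mse <- NA_real_
for (i in seq_len(nrow(grid))) {
  fit <- ranger(Sepal.Length ~ ., data = train,
                quantreg      = TRUE,
                mtry          = grid$mtry[i],
                min.node.size = grid$min.node.size[i],
                splitrule     = grid$splitrule[i],
                seed          = 326)       # same seed so rows are comparable
  grid$oob_mse[i] <- fit$prediction.error  # OOB MSE for regression forests
}

# best few combinations by OOB error
head(grid[order(grid$oob_mse), ])
```

On a data set this small the differences between neighbouring grid rows are likely to be within noise, so I would not read much into the exact ranking.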

#------------ ranger model with options --------------
# drop a predictor, lower mtry, min.node.size
set.seed(326)
fit.r3 <- ranger(Sepal.Length ~ ., 
                 data = iris[tr, -4], # dropped Petal.Width (column 4)
                 write.forest = TRUE,
                 importance = 'permutation',
                 quantreg = TRUE,
                 keep.inbag = TRUE,
                 replace = FALSE,
                 splitrule = "extratrees",
                 mtry = 2,            # has to change (var count lower)
                 min.node.size = 4,   # lowered
                 num.threads = 15)
fit.r3
# Ranger result
# Type:                             Regression 
# Number of trees:                  500 
# Sample size:                      120 
# Number of independent variables:  3 
# Mtry:                             2 
# Target node size:                 6 
# Variable importance mode:         permutation 
# Splitrule:                        extratrees 
# Number of random splits:          1 
# OOB prediction error (MSE):       0.1050143 
# R squared (OOB):                  0.8543842  

The second most important predictor was dropped, and the OOB error improved.

p.r3 <- predict(fit.r3, iris[-tr, -c(1, 4)],
                type = 'quantiles')

postResample(p.r3$predictions[, 1], iris[-tr, 1])
#      RMSE  Rsquared       MAE 
# 0.4760952 0.8089810 0.3800000  

postResample(p.r3$predictions[, 2], iris[-tr, 1])
#      RMSE  Rsquared       MAE 
# 0.3738315 0.7769388 0.3250000  

postResample(p.r3$predictions[, 3], iris[-tr, 1])
#      RMSE  Rsquared       MAE 
# 0.6085584 0.8032592 0.5170000   

importance(fit.r3)
# almost everything relies on Petal.Length
#  Sepal.Width Petal.Length      Species 
#   0.08008264   0.95440333   0.32570147
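
None of the models above were cross-validated, so here is a rough sketch of how k-fold cross-validation could be layered on top, scoring each candidate by its average pinball loss across folds. The assumptions are mine: `pinball_loss` is an illustrative helper, the candidate settings are hypothetical, and the folds come from caret's `createFolds()` on the training rows only, so the test partition stays untouched.

```r
library(ranger)
library(caret)

# pinball (quantile) loss, averaged over observations
pinball_loss <- function(actual, predicted, tau) {
  err <- actual - predicted
  mean(pmax(tau * err, (tau - 1) * err))
}

data(iris)
set.seed(326)
tr    <- createDataPartition(iris$Species, p = .8, list = FALSE)
train <- iris[tr[, 1], ]

taus  <- c(.1, .5, .9)
folds <- createFolds(train$Sepal.Length, k = 5)   # list of held-out row indices

# hypothetical candidate settings to compare
candidates <- expand.grid(mtry = c(2, 3), min.node.size = c(4, 6))

candidates$cv_pinball <- sapply(seq_len(nrow(candidates)), function(i) {
  fold_loss <- sapply(folds, function(held_out) {
    fit <- ranger(Sepal.Length ~ ., data = train[-held_out, ],
                  quantreg      = TRUE,
                  mtry          = candidates$mtry[i],
                  min.node.size = candidates$min.node.size[i])
    pred <- predict(fit, train[held_out, ], type = "quantiles",
                    quantiles = taus)$predictions
    # average the loss over the three quantiles for this fold
    mean(sapply(seq_along(taus), function(j) {
      pinball_loss(train$Sepal.Length[held_out], pred[, j], taus[j])
    }))
  })
  mean(fold_loss)   # average over the folds
})

candidates[order(candidates$cv_pinball), ]
```

The setting with the lowest `cv_pinball` would then be refit on all of the training rows and checked once against the held-out test rows.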