使用 createDataPartition 的分层抽样将小 类 排除在测试之外

Stratified sampling using createDataPartition drops small classes out of test

我正在尝试进行分层抽样,我意识到当我的 类 案例很少时,我最终得到的测试数据集可能没有这些少数案例 类.

这是一些示例代码

library(caret)

# data set for debugging in RStudio
data("imports85")
input<-imports85
    
# settings
set.seed(1)
dependent <- make.names("make")
training.share <- 0.75
impute <- "no"
type <- "classification"

# save original column names for later and make R-friendly column names
original.names <- names(input)
names(input) <- make.names(original.names)
    
# create train and test data sets
input.labelled <- input[complete.cases(input[,dependent]),] #split off rows w/o dependent
if (impute=="no") { 
    input.clean <- input.labelled[complete.cases(input.labelled),] #drop cases w/ missing variables
} else if (impute=="yes") {
    input.clean <- rfImpute(input.labelled[,dependent] ~ .,input.labelled)[,-1] #or impute missing variables and remove added duplicate of dependent column
}

train.index <- createDataPartition(input.clean[,dependent], p=training.share, list=FALSE) #create row index for train data set using stratified sampling but very small classes might all go into train?!
rf.train <- input.clean[train.index,] #create train data set
rf.test <- input.clean[-train.index,] #create test data set from left-overs
if (type=="classification") { #balance train data set for classification (can be skipped if upsampling takes place as part of tuning settings cntrl)
    rf.train <- upSample(x=rf.train[, names(rf.train) != dependent], y=rf.train[, names(rf.train) == dependent], yname=dependent)
}

# define variables Y and dependent x
Y.train <- rf.train[, names(rf.train) == dependent]
x.train <- rf.train[, names(rf.train) != dependent]
Y.test <- rf.test[, names(rf.test) == dependent]
x.test <- rf.test[, names(rf.test) != dependent]

# train single RF model
rf <- randomForest(x.train, y=Y.train, xtest=x.test, ytest=Y.test, type=type, keep.forest=TRUE)

您会收到来自 createDataPartition 的警告,您会看到例如“make”==chevrolet 在 rf.train 中有 3 个案例,在 [=24= 中有 none ],这可能会导致 randomForest.

下游出现问题

有什么聪明的方法可以避免 w/o 将数据从火车泄漏到测试中?

很多都是一样的,但不是全部。

相同:

这是因为你的因变量。您选择了 make。你检查过这个领域吗?你有培训和测试;你把只有一个观察结果的结果放在哪里,比如make = "mercury"?你怎么能用那个训练?如果你不为此训练,你怎么能测试它?

input %>% 
  group_by(make) %>% 
  summarise(count = n()) %>% 
  arrange(count) %>% 
  print(n = 22)

# # A tibble: 22 × 2
#    make        count
#    <fct>       <int>
#  1 mercury         1
#  2 renault         2
#  3 alfa-romero     3
#  4 chevrolet       3
#  5 jaguar          3
#  6 isuzu           4
#  7 porsche         5
#  8 saab            6
#  9 audi            7
# 10 plymouth        7
# 11 bmw             8
# 12 mercedes-benz   8
# 13 dodge           9
# 14 peugot         11
# 15 volvo          11
# 16 subaru         12
# 17 volkswagen     12
# 18 honda          13
# 19 mitsubishi     13
# 20 mazda          17
# 21 nissan         18
# 22 toyota         32

当您执行函数时createDataPartition(), you also had warnings。我认为 randomForest 套餐要求每组至少五个。您可以筛选要包含的组,并将该数据用于测试和培训。

在标记为 settings 的评论之前,您可以添加以下内容以对组进行子集化并验证结果。

filtGrps <- input %>% 
  group_by(make) %>% 
  summarise(count = n()) %>% 
  filter(count >=5) %>% 
  select(make) %>% 
  unlist()

# filter for groups with sufficient observations for package
input <- input %>% 
  filter(make %in% filtGrps) %>% 
  droplevels() # then drop the empty levels

# check to see if it filtered as expected
input %>% 
  group_by(make) %>% 
  summarise(count = n()) %>% 
  arrange(-count) %>% 
  print(n = 16)

这只使用了 5 个,这并不理想。 (越多越好。)

这里改

caret 模型中,您使用了插补。你没有为这个模型这样做。您在创建 input.clean 时删除了另外 34 个观察值。那时...

# you removed another 34 rows- need to check the classes, again
# you imputed for caret/train
input.clean %>% 
  group_by(make) %>% 
  summarise(count = n()) %>% 
  arrange(-count) %>% 
  print(n = 16)
# # A tibble: 16 × 2
#    make          count
#    <fct>         <int>
#  1 toyota           31
#  2 nissan           18
#  3 honda            13
#  4 subaru           12
#  5 mazda            11
#  6 volvo            11
#  7 mitsubishi       10
#  8 dodge             8
#  9 volkswagen        8
# 10 peugot            7
# 11 plymouth          6
# 12 saab              6
# 13 mercedes-benz     5
# 14 audi              4
# 15 bmw               4
# 16 porsche           1 

你现在需要再丢三个 类。

# there is an exclamation point to negate this
input.clean <- input.clean %>% 
  filter(!make %in% c("audi", "bmw", "porsche")) %>% 
  droplevels()

# validate changes
input.clean %>% 
  group_by(make) %>% 
  summarise(count = n()) %>% 
  arrange(-count) %>% 
  print(n = 16)
# 13 classes now

从这里开始,您的代码就可以使用了。

rf
# 
# Call:
#  randomForest(x = x.train, y = Y.train, xtest = x.test, ytest = Y.test,      keep.forest = TRUE, type = type) 
#                Type of random forest: classification
#                      Number of trees: 500
# No. of variables tried at each split: 5
# 
#         OOB estimate of  error rate: 1.92%
# Confusion matrix:
#               dodge honda mazda mercedes-benz mitsubishi nissan peugot
# dodge            24     0     0             0          0      0      0
# honda             0    22     0             0          2      0      0
# mazda             0     0    24             0          0      0      0
# mercedes-benz     0     0     0            24          0      0      0
# mitsubishi        0     0     0             0         23      0      0
# nissan            0     0     0             0          0     23      0
# peugot            0     0     0             0          0      0     24
# plymouth          0     0     0             0          0      0      0
# saab              0     0     0             0          0      0      0
# subaru            0     0     0             0          0      0      0
# toyota            0     0     0             0          0      1      0
# volkswagen        0     0     0             0          0      0      0
# volvo             0     0     0             0          0      0      0
#               plymouth saab subaru toyota volkswagen volvo class.error
# dodge                0    0      0      0          0     0  0.00000000
# honda                0    0      0      0          0     0  0.08333333
# mazda                0    0      0      0          0     0  0.00000000
# mercedes-benz        0    0      0      0          0     0  0.00000000
# mitsubishi           1    0      0      0          0     0  0.04166667
# nissan               0    0      0      1          0     0  0.04166667
# peugot               0    0      0      0          0     0  0.00000000
# plymouth            24    0      0      0          0     0  0.00000000
# saab                 0   24      0      0          0     0  0.00000000
# subaru               0    0     24      0          0     0  0.00000000
# toyota               0    0      0     22          0     1  0.08333333
# volkswagen           0    0      0      0         24     0  0.00000000
# volvo                0    0      0      0          0    24  0.00000000
#                 Test set error rate: 3.23%

提示 - 如果您在同一个脚本文件中进行这些调用,请在模型之间使用唯一的对象名称,这样,您始终知道哪个对象中有哪些数据。它可能是导致各种问题的隐藏错误。