How to construct a confusion matrix when cross-validating with k-NN in R

I tried asking this question on another forum but did not get any replies, so I am reposting it here with a more specific description of my problem. I have a dataset that looks like this:

> head(knnresults)
   ACTIVITY_X ACTIVITY_Y ACTIVITY_Z classification
1:         40         47         62        Feeding
2:         60         74         95       Standing
3:         62         63         88       Standing
4:         60         56         82       Standing
5:         66         61         90       Standing
6:         60         53         80       Standing

classification contains three distinct classes: Feeding, Standing and Foraging.

I am currently selecting an optimal k value, which is why I set aside 20% of the data to be classified and use the other 80% for training. Classification is based on the values in the first three columns. The k value with the highest accuracy will then be used for subsequent classification analyses.

This is the script I have been using:

library(ISLR)
library(caret)
library(lattice)
library(ggplot2)

# Split the data into training (80%) and testing (20%) sets.
# createDataPartition() returns a one-column matrix; drop it to a plain
# vector so row indexing also works if knnresults is a data.table.
indxTrain <- createDataPartition(y = knnresults$classification, p = 0.8, list = FALSE)[, 1]
training <- knnresults[indxTrain, ]
testing  <- knnresults[-indxTrain, ]

# Run k-NN, tuning k by repeated cross-validation on the training set:
set.seed(400)
ctrl <- trainControl(method = "repeatedcv", repeats = 3)
knnFit <- train(classification ~ ., data = training, method = "knn",
                trControl = ctrl, preProcess = c("center", "scale"),
                tuneLength = 20)
knnFit

# Plot accuracy against different k values (based on repeated cross-validation)
plot(knnFit)

First, my apologies: I am new to R and not sure whether this script is sound. If you find mistakes, I will gladly accept any suggested corrections.

Second, how can I access the confusion matrix of the classification produced by this code? This is important for calculating the performance metrics associated with the classification.

In case it helps, here is the dput() of the dataset:

> dput(knnresults)
structure(list(ACTIVITY_X = c(40L, 60L, 62L, 60L, 66L, 60L, 57L, 
54L, 52L, 93L, 80L, 14L, 61L, 51L, 40L, 20L, 21L, 5L, 53L, 48L, 
73L, 73L, 21L, 29L, 63L, 59L, 57L, 51L, 53L, 67L, 72L, 74L, 70L, 
60L, 74L, 85L, 77L, 68L, 58L, 80L, 34L, 45L, 34L, 60L, 75L, 62L, 
66L, 51L, 53L, 48L, 62L, 62L, 57L, 5L, 1L, 12L, 23L, 5L, 4L, 
0L, 13L, 45L, 44L, 31L, 68L, 88L, 43L, 70L, 18L, 83L, 71L, 67L, 
75L, 74L, 49L, 90L, 44L, 64L, 57L, 22L, 29L, 52L, 37L, 32L, 120L, 
45L, 22L, 54L, 30L, 9L, 27L, 14L, 3L, 29L, 12L, 61L, 60L, 29L, 
15L, 7L, 6L, 0L, 2L, 0L, 4L, 1L, 7L, 0L, 0L, 0L, 0L, 0L, 1L, 
23L, 49L, 46L, 8L, 31L, 45L, 60L, 37L, 61L, 52L, 51L, 38L, 86L, 
60L, 41L, 43L, 40L, 42L, 42L, 48L, 64L, 71L, 59L, 0L, 27L, 12L, 
3L, 0L, 0L, 8L, 21L, 6L, 2L, 7L, 4L, 3L, 3L, 46L, 46L, 59L, 53L, 
37L, 44L, 39L, 49L, 37L, 47L, 17L, 36L, 32L, 33L, 26L, 12L, 8L, 
31L, 35L, 27L, 27L, 24L, 17L, 35L, 39L, 28L, 54L, 5L, 0L, 0L, 
0L, 0L, 17L, 22L, 25L, 12L, 0L, 5L, 41L, 51L, 66L, 39L, 32L, 
53L, 43L, 40L, 44L, 45L, 48L, 51L, 41L, 45L, 39L, 46L, 59L, 31L, 
5L, 24L, 18L, 5L, 15L, 13L, 0L, 26L, 0L), ACTIVITY_Y = c(47L, 
74L, 63L, 56L, 61L, 53L, 40L, 41L, 49L, 32L, 54L, 13L, 99L, 130L, 
38L, 14L, 6L, 5L, 94L, 96L, 38L, 43L, 29L, 47L, 66L, 47L, 38L, 
31L, 36L, 35L, 38L, 72L, 54L, 44L, 45L, 51L, 80L, 48L, 39L, 85L, 
42L, 39L, 37L, 75L, 36L, 45L, 32L, 35L, 41L, 26L, 99L, 163L, 
124L, 0L, 0L, 24L, 37L, 0L, 6L, 0L, 29L, 29L, 26L, 27L, 54L, 
147L, 82L, 98L, 12L, 83L, 97L, 104L, 128L, 81L, 42L, 102L, 60L, 
79L, 58L, 15L, 14L, 75L, 75L, 40L, 130L, 40L, 13L, 54L, 42L, 
7L, 10L, 3L, 0L, 15L, 8L, 75L, 55L, 26L, 18L, 1L, 13L, 0L, 0L, 
0L, 1L, 0L, 4L, 0L, 0L, 0L, 0L, 0L, 0L, 17L, 45L, 38L, 10L, 31L, 
52L, 36L, 65L, 97L, 45L, 59L, 49L, 92L, 51L, 34L, 21L, 20L, 29L, 
28L, 22L, 32L, 30L, 86L, 0L, 15L, 7L, 4L, 0L, 0L, 0L, 11L, 3L, 
0L, 1L, 3L, 1L, 0L, 72L, 62L, 98L, 55L, 26L, 39L, 28L, 81L, 20L, 
52L, 12L, 48L, 24L, 40L, 30L, 5L, 6L, 40L, 37L, 33L, 26L, 17L, 
14L, 39L, 27L, 28L, 67L, 0L, 0L, 0L, 0L, 0L, 10L, 12L, 14L, 7L, 
0L, 2L, 39L, 67L, 74L, 28L, 23L, 57L, 34L, 36L, 36L, 37L, 46L, 
43L, 73L, 65L, 31L, 64L, 128L, 17L, 3L, 12L, 17L, 0L, 9L, 7L, 
0L, 17L, 0L), ACTIVITY_Z = c(62L, 95L, 88L, 82L, 90L, 80L, 70L, 
68L, 71L, 98L, 97L, 19L, 116L, 140L, 55L, 24L, 22L, 7L, 108L, 
107L, 82L, 85L, 36L, 55L, 91L, 75L, 69L, 60L, 64L, 76L, 81L, 
103L, 88L, 74L, 87L, 99L, 111L, 83L, 70L, 117L, 54L, 60L, 50L, 
96L, 83L, 77L, 73L, 62L, 67L, 55L, 117L, 174L, 136L, 5L, 1L, 
27L, 44L, 5L, 7L, 0L, 32L, 54L, 51L, 41L, 87L, 171L, 93L, 120L, 
22L, 117L, 120L, 124L, 148L, 110L, 65L, 136L, 74L, 102L, 81L, 
27L, 32L, 91L, 84L, 51L, 177L, 60L, 26L, 76L, 52L, 11L, 29L, 
14L, 3L, 33L, 14L, 97L, 81L, 39L, 23L, 7L, 14L, 0L, 2L, 0L, 4L, 
1L, 8L, 0L, 0L, 0L, 0L, 0L, 1L, 29L, 67L, 60L, 13L, 44L, 69L, 
70L, 75L, 115L, 69L, 78L, 62L, 126L, 79L, 53L, 48L, 45L, 51L, 
50L, 53L, 72L, 77L, 104L, 0L, 31L, 14L, 5L, 0L, 0L, 8L, 24L, 
7L, 2L, 7L, 5L, 3L, 3L, 85L, 77L, 114L, 76L, 45L, 59L, 48L, 95L, 
42L, 70L, 21L, 60L, 40L, 52L, 40L, 13L, 10L, 51L, 51L, 43L, 37L, 
29L, 22L, 52L, 47L, 40L, 86L, 5L, 0L, 0L, 0L, 0L, 20L, 25L, 29L, 
14L, 0L, 5L, 57L, 84L, 99L, 48L, 39L, 78L, 55L, 54L, 57L, 58L, 
66L, 67L, 84L, 79L, 50L, 79L, 141L, 35L, 6L, 27L, 25L, 5L, 17L, 
15L, 0L, 31L, 0L), classification = c("Feeding", "Standing", 
"Standing", "Standing", "Standing", "Standing", "Feeding", "Feeding", 
"Feeding", "Standing", "Standing", "Foraging", "Standing", "Standing", 
"Feeding", "Foraging", "Foraging", "Foraging", "Standing", "Standing", 
"Standing", "Standing", "Feeding", "Feeding", "Standing", "Feeding", 
"Feeding", "Feeding", "Feeding", "Feeding", "Standing", "Standing", 
"Standing", "Feeding", "Standing", "Standing", "Standing", "Standing", 
"Feeding", "Standing", "Feeding", "Feeding", "Feeding", "Standing", 
"Standing", "Feeding", "Feeding", "Feeding", "Feeding", "Feeding", 
"Standing", "Standing", "Standing", "Foraging", "Foraging", "Foraging", 
"Feeding", "Foraging", "Foraging", "Foraging", "Foraging", "Feeding", 
"Feeding", "Feeding", "Standing", "Standing", "Standing", "Standing", 
"Foraging", "Standing", "Standing", "Standing", "Standing", "Standing", 
"Feeding", "Standing", "Feeding", "Standing", "Standing", "Foraging", 
"Foraging", "Standing", "Feeding", "Feeding", "Standing", "Feeding", 
"Foraging", "Feeding", "Feeding", "Foraging", "Foraging", "Foraging", 
"Foraging", "Foraging", "Foraging", "Standing", "Standing", "Feeding", 
"Foraging", "Foraging", "Foraging", "Foraging", "Foraging", "Foraging", 
"Foraging", "Foraging", "Foraging", "Foraging", "Foraging", "Foraging", 
"Foraging", "Foraging", "Foraging", "Foraging", "Feeding", "Feeding", 
"Foraging", "Feeding", "Feeding", "Feeding", "Feeding", "Standing", 
"Feeding", "Feeding", "Feeding", "Standing", "Standing", "Feeding", 
"Feeding", "Feeding", "Feeding", "Feeding", "Feeding", "Feeding", 
"Standing", "Standing", "Foraging", "Foraging", "Foraging", "Foraging", 
"Foraging", "Foraging", "Foraging", "Foraging", "Foraging", "Foraging", 
"Foraging", "Foraging", "Foraging", "Foraging", "Standing", "Feeding", 
"Standing", "Feeding", "Feeding", "Feeding", "Feeding", "Standing", 
"Feeding", "Feeding", "Foraging", "Feeding", "Feeding", "Feeding", 
"Feeding", "Foraging", "Foraging", "Feeding", "Feeding", "Feeding", 
"Feeding", "Foraging", "Foraging", "Feeding", "Feeding", "Feeding", 
"Standing", "Foraging", "Foraging", "Foraging", "Foraging", "Foraging", 
"Foraging", "Foraging", "Foraging", "Foraging", "Foraging", "Foraging", 
"Feeding", "Standing", "Standing", "Feeding", "Feeding", "Feeding", 
"Feeding", "Feeding", "Feeding", "Feeding", "Feeding", "Feeding", 
"Feeding", "Feeding", "Feeding", "Feeding", "Standing", "Feeding", 
"Foraging", "Foraging", "Foraging", "Foraging", "Foraging", "Foraging", 
"Foraging", "Foraging", "Foraging")), row.names = c(NA, -215L
), class = c("data.table", "data.frame"), .internal.selfref = <pointer: 0x0000000002531ef0>)

Any input is greatly appreciated!

Here is a reproducible example:

library(caret)

# Hold out 20% of iris as a validation set; train on the remaining 80%.
train_set <- createDataPartition(iris$Species, p = 0.8, list = FALSE)
valid_set <- iris[-train_set, ]
train_set <- iris[train_set, ]

# Fit k-NN with 5-fold cross-validation on the training set.
ctrl <- trainControl(method = "cv", number = 5)
set.seed(233)
mk <- train(Species ~ ., data = train_set, method = "knn",
            trControl = ctrl, metric = "Accuracy")

That gives you the fitted model; the confusion matrix is obtained with confusionMatrix() as shown below. Ideally, though, you would compare the predicted values from your trained model against a test or validation set.
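
A minimal sketch of that comparison on held-out data, reusing the mk model and valid_set created above, might look like this:

# Predict the held-out validation set with the fitted k-NN model,
# then cross-tabulate the predictions against the true classes.
pred <- predict(mk, newdata = valid_set)
confusionMatrix(data = pred, reference = valid_set$Species)

For the original data, the analogous call would compare predict(knnFit, newdata = testing) against factor(testing$classification).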

Edit: to retrieve just the table, simply do:

confusionMatrix(mk)["table"]
$table
            Reference
Prediction       setosa versicolor  virginica
  setosa     33.3333333  0.0000000  0.0000000
  versicolor  0.0000000 32.5000000  2.5000000
  virginica   0.0000000  0.8333333 30.8333333

Original:

 confusionMatrix(mk)

Result:

Cross-Validated (5 fold) Confusion Matrix 

(entries are percentual average cell counts across resamples)

            Reference
Prediction   setosa versicolor virginica
  setosa       33.3        0.0       0.0
  versicolor    0.0       31.7       1.7
  virginica     0.0        1.7      31.7

 Accuracy (average) : 0.9667
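
If raw counts are preferred over the percentage cells above, confusionMatrix() on a caret train object also accepts a norm argument; the following is a sketch assuming norm = "none" is supported by your caret version:

# Report aggregated raw cell counts across resamples instead of percentages
confusionMatrix(mk, norm = "none")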