无法使用 geom_contour() 在 R 中绘制决策边界

Unable to plot Decision Boundary in R with geom_contour()

我知道有几个类似的问题,但我仍在努力理解它背后的概念。

我使用 knn 预测类别(“true” 和 “false”):

# Classify every test row with k-nearest-neighbours. Columns 2:3 of both data
# frames are used as features and `Selected` supplies the training labels.
hype_knn_prediction <- knn(
  train = hype_quant_train[, 2:3],
  test = hype_quant_test[, 2:3],
  cl = hype_quant_train$Selected,
  k = k
)

而且我还可以绘制测试数据:

# Scatter plot of the test set: one point per row, with colour and shape both
# encoding the class stored in `Selected`. The title is centred and the legend
# is placed underneath the panel.
ggplot(
  data = hype_quant_test,
  mapping = aes(x = Comments, y = Votings, colour = Selected, shape = Selected)
) +
  geom_point(size = 3) +
  labs(title = "Testdata", x = "Comments", y = "Votes") +
  theme(
    plot.title = element_text(hjust = 0.5),
    legend.position = "bottom"
  )

Plot of test data

现在我想把类别的决策边界添加到测试数据的图中。为此,我将 hype_knn_prediction 作为一列添加到了 hype_quant_test 数据框中。但是,当我把 geom_contour(data=hype_quant_test, aes(x=Comments, y=Votings, z= as.numeric(Selected)), breaks=c(0,.5)) 添加到图中时,收到以下消息:stat_contour() 中的计算失败:x 坐标的数量必须与密度矩阵中的列数匹配。

我该如何解决这个问题?我猜需要先对数据做某种转换,但不知道该怎么做。

编辑

训练数据:

   Selected  Votings          Comments
1      true  0.2348563517    0.162454874
2     false  0.0027691243    0.001805054
3     false  0.0136725511    0.027075812
4     false  0.1128418138    0.077617329
5     false  0.0529595016    0.016245487
6     false  0.0190377293    0.012635379
7     false  0.0231914157    0.001805054
8     false  0.3367947387    0.019855596
9     false  0.0036344756    0.005415162
10    false  0.0051921080    0.005415162
11    false  0.0202492212    0.014440433
12    false  0.0178262375    0.007220217
13    false  0.0029421945    0.010830325
14    false  0.0680166147    0.036101083
15    false  0.0053651783    0.003610108
16    false  0.2397023191    0.034296029
17    false  0.0001730703    0.000000000
18    false  0.0228452752    0.023465704
19    false  0.0129802700    0.000000000
20    false  0.0192107996    0.018050542
21    false  0.0010384216    0.000000000
22    false  0.0129802700    0.005415162
23    false  0.0000000000    0.000000000
24    false  0.0134994808    0.003610108
25    false  0.0742471443    0.039711191
26    false  0.0256143994    0.009025271
27    false  0.0039806161    0.001805054
28     true  0.4110418830    0.050541516
29    false  0.0114226376    0.063176895
30    false  0.0185185185    0.016245487
31    false  0.0051921080    0.003610108
32    false  0.1952232606    0.021660650
33    false  0.1138802354    0.012635379
34    false  0.0048459675    0.016245487
35    false  0.0242298373    0.009025271
36    false  0.0167878159    0.001805054
37    false  0.0039806161    0.001805054
38     true  0.7727587400    0.146209386
39    false  0.0154032537    0.000000000
40    false  0.0057113188    0.007220217
41    false  0.0038075459    0.000000000
42    false  0.0046728972    0.001805054
43    false  0.0152301835    0.003610108
44    false  0.0408445829    0.025270758
45    false  0.0131533403    0.007220217
46    false  0.0578054690    0.037906137
47    false  0.0046728972    0.005415162
48    false  0.0001730703    0.001805054
49    false  0.1169955002    0.122743682
50    false  0.0044998269    0.003610108
51    false  0.0000000000    0.000000000
52    false  0.1439944618    0.036101083
53    false  0.0072689512    0.005415162
54    false  0.0064035999    0.009025271
55    false  0.0614399446    0.027075812
56    false  0.0719972309    0.005415162
57     true  0.3418137764    0.018050542
58    false  0.0117687781    0.012635379
59    false  0.0072689512    0.014440433
60     true  0.0313257182    0.018050542
61    false  0.1021114573    0.019855596
62    false  0.0024229837    0.003610108
63    false  0.0072689512    0.000000000
64    false  0.0169608861    0.003610108
65    false  0.0340948425    0.014440433
66     true  0.7069920388    0.332129964
67     true  0.7377985462    0.175090253
68    false  0.0919003115    0.007220217
69    false  0.0065766701    0.001805054
70    false  0.0401523018    0.027075812
71    false  0.0223260644    0.005415162
72    false  0.0635167878    0.018050542
73    false  0.0013845621    0.000000000
74    false  0.0060574593    0.000000000
75     true  0.6102457598    0.909747292
76    false  0.0022499135    0.001805054
77    false  0.0316718588    0.007220217
78    false  0.0019037729    0.000000000
79     true  1.0000000000    1.000000000
80    false  0.0240567670    0.016245487

添加预测列后的测试数据:

   Selected   Votings          Comments   Prediction
1     false   0.329525787    0.023465704      false
2     false   0.299930772    0.075812274      false
3      true   0.962443752    0.178700361       true
4     false   0.032191070    0.001805054      false
5     false   0.036863967    0.025270758      false
6     false   0.014884043    0.005415162      false
7     false   0.034787124    0.005415162      false
8     false   0.007615092    0.000000000      false
9     false   0.005538249    0.000000000      false
10    false   0.006403600    0.005415162      false
11    false   0.006749740    0.005415162      false
12    false   0.048286604    0.072202166      false
13    false   0.057286258    0.021660650      false
14    false   0.067324334    0.012635379      false
15    false   0.004153686    0.001805054      false
16    false   0.004845967    0.003610108      false
17    false   0.089131187    0.055956679      false
18    false   0.010384216    0.001805054      false
19    false   0.040671513    0.021660650      false
20    false   0.001903773    0.001805054      false

编辑

我尝试使用默认的 Iris 数据测试绘图,但消息仍然相同:

library(datasets)
library(class)
library(ggplot2)
library(caret)

# Work on a data-frame copy so the built-in `iris` dataset stays untouched.
iris_df <- as.data.frame(iris)

# Min-max scaling: linearly map a numeric vector onto [0, 1], sending its
# smallest value to 0 and its largest value to 1.
# A vector whose values are all equal yields NaN (0 / 0).
normalize <- function(x) {
  lowest <- min(x)
  highest <- max(x)
  (x - lowest) / (highest - lowest)
}

# Rescale the four measurement columns with normalize().
feature_cols <- c("Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width")
iris_df[feature_cols] <- lapply(iris_df[feature_cols], normalize)

# Fix the RNG seed so the train/test split below is reproducible.
set.seed(1234)

# Draw 80% of the row indices, without replacement, for training.
train_rows <- sample(nrow(iris_df), nrow(iris_df) * 0.8, replace = FALSE)

# Training rows are the sampled ones; the remaining rows form the test set.
iris_train <- iris_df[train_rows, ]
iris_test <- iris_df[-train_rows, ]

# Shared base for the scatter plots: sepal length against sepal width, with
# colour and shape both encoding the species.
species_scatter <- function(dat) {
  ggplot(
    dat,
    aes(x = Sepal.Length, y = Sepal.Width, colour = Species, shape = Species)
  ) +
    geom_point(size = 3)
}

species_scatter(iris_train) + theme(legend.position = "bottom")
species_scatter(iris_test) + theme(legend.position = "bottom")

# Neighbourhood size: square root of the number of training rows, rounded.
k <- round(sqrt(nrow(iris_train)))

# Classify the test rows from all four features and keep the result as a
# new column next to the true labels.
knn_predict <- knn(
  train = iris_train[, 1:4],
  test = iris_test[, 1:4],
  cl = iris_train$Species,
  k = k
)
iris_test$Prediction <- knn_predict

# Cross-tabulate predicted against actual species.
confusionMatrix(iris_test$Prediction, iris_test$Species)

# NOTE: this contour layer is the call that produces the stat_contour()
# failure quoted right after this script. It is fed the scattered test rows
# directly rather than predictions on a regular x/y grid.
species_scatter(iris_test) +
  geom_contour(
    data = iris_test,
    aes(x = Sepal.Length, y = Sepal.Width, z = as.numeric(Prediction)),
    breaks = c(0, .5)
  )

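运行后得到如下警告:stat_contour() 中的计算失败:x 坐标的数量必须与密度矩阵中的列数匹配。(英文原文:Computation failed in `stat_contour()`: Number of x coordinates must match number of columns in density matrix.)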

我认为使用 geom_contour 的关键一步,是在由两个变量构成的规则网格(矩阵)上进行预测,而我在您的 iris 代码中没有看到这一步。下面演示如何构建这样的网格并在其上预测。在预处理方面我没有做任何特殊处理。

library(class)
library(ggplot2)

# Randomly pick 75 row indices of iris to train on. The split is random;
# call set.seed() beforehand if it needs to be reproducible.
train <- sample(nrow(iris), 75)

# Feature-only copies of both halves (column 5, Species, is dropped).
# `test_dat` is not used further down in this snippet.
train_dat <- iris[train, -5]
test_dat <- iris[-train, -5]

# The two features that span the plotting plane.
vars <- c("Sepal.Width", "Sepal.Length")

# First make a grid: n evenly spaced values per feature over its observed
# range, combined into all n * n (Sepal.Width, Sepal.Length) pairs.
n <- 40
grid_axis <- function(values) {
  seq(min(values), max(values), length.out = n)
}
pred.mat <- expand.grid(
  Sepal.Width = grid_axis(iris$Sepal.Width),
  Sepal.Length = grid_axis(iris$Sepal.Length)
)

# Then ask for prediction on the grid.
# Both sides are subset by `vars`, so the feature columns line up by name and
# the statement can be re-run after the `pred` column has been added.
pred.mat$pred <- knn(
  train = train_dat[, vars],
  test = pred.mat[, vars],
  cl = iris$Species[train],
  k = 3
)

# Use grid as input for geom_contour.
# One layer per species: z is 1 where the grid point is predicted as that
# species and 0 elsewhere, so the 0.5 contour traces that class's boundary.
boundary_layers <- lapply(levels(iris$Species), function(species) {
  geom_contour(
    aes(z = as.numeric(pred == species), colour = species),
    breaks = 0.5
  )
})

ggplot(pred.mat, aes(Sepal.Width, Sepal.Length)) +
  geom_point(data = iris, aes(colour = Species)) +
  boundary_layers

由 reprex 包 (v0.3.0) 于 2020-04-18 创建