在 R 中绘制图形

Plotting a Graph in R

我参考了这篇文章:https://www.r-bloggers.com/2021/02/how-to-build-a-handwritten-digit-classifier-with-r-and-random-forests/

使用 R 和随机森林编写手写数字分类器。

是否可以把代码末尾得到的 colMeans 结果画成图?MNIST 训练集和测试集(可以在上面的链接中找到)没有任何列标题。我是 R 新手,还在学习中,任何帮助都将不胜感激。

代码如下:

library(readr)

# Read the MNIST training and test CSVs. They have no header row
# (col_names = FALSE), so readr names the columns X1, X2, ...; X1 holds the
# digit label (the remaining columns are presumably the pixel values).
train_set <- read_csv(file = "mnist_train.csv", col_names = FALSE)
test_set  <- read_csv(file = "mnist_test.csv",  col_names = FALSE)

# Pull the label column out as a factor so that randomForest treats the
# task as classification rather than regression.
train_labels <- factor(train_set[["X1"]])
test_labels  <- factor(test_set[["X1"]])

# Peek at the first ten labels.
head(train_labels, n = 10)

# Count the training examples available for each digit (0 to 9).
summary(train_labels)

#importing random forest
library(randomForest)

# Train a 50-tree random forest on the pixel columns only. Column 1 (X1) is
# the digit label itself, so it is dropped from both `x` and `xtest`;
# passing the whole tables would leak the label into the predictors.
# Supplying `ytest` makes randomForest also report test-set error rates
# (in rf$test) next to the out-of-bag (OOB) estimates.
rf <- randomForest(
  x = train_set[, -1],
  y = train_labels,
  xtest = test_set[, -1],
  ytest = test_labels,
  ntree = 50
)
rf

# Accuracy = 1 - error rate. rf$err.rate has one row per forest size
# (trees 1..ntree) and the columns "OOB", "0", ..., "9". Averaging the whole
# matrix would mix per-class rates with forests of every size, so take the
# OOB error of the complete forest (the last row) instead.
1 - rf$err.rate[rf$ntree, "OOB"]

# Accuracy on the held-out test set.
mean(rf$test$predicted == test_labels)

# Attach dplyr for select().
library(dplyr)

# Mean error rate per digit: average every class column of the error-rate
# matrix over its rows (one row per forest size), excluding the overall
# "OOB" column.
err_df <- as.data.frame(rf$err.rate)
colMeans(select(err_df, -"OOB"))

colMeans 的输出:[1]

为了加快运行速度,我对训练集和测试集做了大幅的子集化,对您的代码稍作了修改。您可以随意注释掉或删除相关的行。请查看下面的代码,看看这是否是您想要的。

library(readr)
#importing dplyr
library(dplyr)
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union
#importing random forest
library(randomForest)
#> randomForest 4.6-14
#> Type rfNews() to see new features/changes/bug fixes.
#> 
#> Attaching package: 'randomForest'
#> The following object is masked from 'package:dplyr':
#> 
#>     combine
library(ggplot2)
#> 
#> Attaching package: 'ggplot2'
#> The following object is masked from 'package:randomForest':
#> 
#>     margin


# Read the MNIST training and test CSVs from ~/Downloads. They have no header
# row (col_names = FALSE), so readr names the columns X1, X2, ...
train_set <- read_csv(
  file = "~/Downloads/mnist_train.csv",
  col_names = FALSE
)
#> 
#> ── Column specification ────────────────────────────────────────────────────────
#> cols(
#>   .default = col_double()
#> )
#> ℹ Use `spec()` for the full column specifications.
test_set <- read_csv(
  file = "~/Downloads/mnist_test.csv",
  col_names = FALSE
)
#> 
#> ── Column specification ────────────────────────────────────────────────────────
#> cols(
#>   .default = col_double()
#> )
#> ℹ Use `spec()` for the full column specifications.

# Column X1 holds the digit label; keep it as a factor so that randomForest
# treats the task as classification.
train_labels <- factor(train_set[["X1"]])
test_labels  <- factor(test_set[["X1"]])

# Peek at the first ten labels.
head(train_labels, n = 10)
#>  [1] 5 0 4 1 9 2 1 3 1 4
#> Levels: 0 1 2 3 4 5 6 7 8 9

# Count the training examples available for each digit (0 to 9).
summary(train_labels)
#>    0    1    2    3    4    5    6    7    8    9 
#> 5923 6742 5958 6131 5842 5421 5918 6265 5851 5949

# Subsample to speed things up: keep only the first `n_train` training rows
# and the first `n_test` test rows (delete this block to use the full data).
# head() keeps at most that many, so a set smaller than requested is kept
# whole; indexing with 1:n past the end would instead error or yield NA
# rows/labels, depending on the object type.
n_train <- 1000
n_test <- 100
train_set <- head(train_set, n_train)
train_labels <- head(train_labels, n_train)
test_set <- head(test_set, n_test)
test_labels <- head(test_labels, n_test)

# Train a 50-tree random forest on the pixel columns only. Column 1 (X1) is
# the digit label itself, so it is dropped from both `x` and `xtest`;
# passing the whole tables leaks the label into the predictors. Supplying
# `ytest` makes randomForest also report test-set error rates (in rf$test)
# next to the out-of-bag (OOB) estimates.
#
# NOTE: the `#>` output below was captured from the earlier call, which
# still included X1 as a predictor and passed no `ytest`; rerun to refresh.
rf <- randomForest(
  x = train_set[, -1],
  y = train_labels,
  xtest = test_set[, -1],
  ytest = test_labels,
  ntree = 50
)
rf
#> 
#> Call:
#>  randomForest(x = train_set, y = train_labels, xtest = test_set,      ntree = 50) 
#>                Type of random forest: classification
#>                      Number of trees: 50
#> No. of variables tried at each split: 28
#> 
#>         OOB estimate of  error rate: 11.6%
#> Confusion matrix:
#>    0   1  2  3  4  5  6   7  8  9 class.error
#> 0 96   0  0  0  0  0  1   0  0  0  0.01030928
#> 1  0 112  1  1  0  1  0   0  0  1  0.03448276
#> 2  2   6 82  0  2  0  1   4  2  0  0.17171717
#> 3  0   1  2 78  2  5  1   1  2  1  0.16129032
#> 4  0   0  1  0 94  1  2   1  1  5  0.10476190
#> 5  0   0  1  8  3 77  1   0  0  2  0.16304348
#> 6  1   0  1  0  2  2 86   1  1  0  0.08510638
#> 7  0   3  3  2  4  0  0 102  0  3  0.12820513
#> 8  0   1  1  3  1  6  1   1 71  2  0.18390805
#> 9  1   0  0  1  4  1  1   5  1 86  0.14000000

# Accuracy = 1 - error rate. rf$err.rate has one row per forest size
# (trees 1..ntree) and the columns "OOB", "0", ..., "9". Averaging the whole
# matrix (as the earlier `1 - mean(rf$err.rate)`, printed as 0.8012579, did)
# mixes per-class rates with forests of every size, so take the OOB error
# of the complete forest (the last row) instead.
1 - rf$err.rate[rf$ntree, "OOB"]

# Accuracy on the held-out test set.
mean(rf$test$predicted == test_labels)


# Mean error rate for every digit: average each class column of
# rf$err.rate over its rows (one row per forest size 1..ntree), leaving out
# the overall "OOB" column.
err_df <- as.data.frame(rf$err.rate)
mymeans <- err_df %>%
  select(-"OOB") %>%
  colMeans()

# One row per digit. The digit is taken from the column names ("0", ...,
# "9") rather than from the column position, so the x values stay correct
# even if a class is missing or the columns come out in another order.
toplot <- data.frame(
  index = as.integer(names(mymeans)),
  col_means = mymeans
)

# Line + point plot of the mean class error against the digit, with exactly
# one x-axis tick per digit that is present.
ggplot(toplot, aes(x = index, y = col_means)) +
  geom_line() +
  geom_point() +
  scale_x_continuous(breaks = toplot$index) +
  labs(x = "Digit", y = "Mean class error rate (OOB)")

创建于 2021-02-16,由 reprex package (v0.3.0) 生成