如何使用 R `tidymodels` 复制 plot.lda()

How can I replicate plot.lda() with of R `tidymodels`

我想使用 ggplot2tidymodels 复制 plot.lda 打印方法。有没有优雅的方式获取剧情?

我想我可以伪造 augment() 函数,它没有 lda 方法,方法是使用 predict() 并将其绑定到原始数​​据。

这是一个使用基本 R 和 tidymodels 代码的示例:

library(ISLR2)
library(MASS)

# First base R
train <- Smarket$Year < 2005

lda.fit <-
  lda(
    Direction ~ Lag1 + Lag2,
    data = Smarket,
    subset = train
  )

plot(lda.fit)


# Next tidymodels

library(tidyverse)
library(tidymodels)
library(discrim)

lda_spec <- discrim_linear() %>%
  set_mode("classification") %>%
  set_engine("MASS")

the_rec <- recipe(
  Direction ~ Lag1 + Lag2, 
  data = Smarket
)

the_workflow<- workflow() %>% 
  add_recipe(the_rec) %>% 
  add_model(lda_spec)

Smarket_train <- Smarket %>%
  filter(Year != 2005)

the_workflow_fit_lda_fit <- 
  fit(the_workflow, data = Smarket_train) %>% 
  extract_fit_parsnip()

# now my attempt to do the plot

predictions <- predict(the_workflow_fit_lda_fit, 
                 new_data = Smarket_train, 
                 type = "raw"
         )[[3]] %>% 
  as.vector()

bind_cols(Smarket_train, .fitted = predictions) %>% 
  ggplot(aes(x=.fitted)) +
  geom_histogram(aes(y = stat(density)),binwidth = .5) + 
  scale_x_continuous(breaks = seq(-4, 4, by = 2))+
  facet_grid(vars(Direction)) +
  xlab("") + 
  ylab("Density")

一定有更好的方法来做到这一点....有什么想法?

您可以结合使用 extract_fit_*()parsnip:::repair_call() 来完成此操作。 plot.lda() 方法使用了 LDA 拟合中的 $call 对象,我们需要对其进行调整,因为使用 tidymodels 的调用对象将不同于直接使用 lda()

library(ISLR2)
library(MASS)

# First base R
train <- Smarket$Year < 2005

lda.fit <-
  lda(
    Direction ~ Lag1 + Lag2,
    data = Smarket,
    subset = train
  )

# Next tidymodels

library(tidyverse)
library(tidymodels)
library(discrim)

lda_spec <- discrim_linear() %>%
  set_mode("classification") %>%
  set_engine("MASS")

the_rec <- recipe(
  Direction ~ Lag1 + Lag2, 
  data = Smarket
)

the_workflow <- workflow() %>% 
  add_recipe(the_rec) %>% 
  add_model(lda_spec)

Smarket_train <- Smarket %>%
  filter(Year != 2005)

the_workflow_fit_lda_fit <- 
  fit(the_workflow, data = Smarket_train)

拟合两个模型后,我们可以检查 $call 个对象,我们发现它们不同。

lda.fit$call
#> lda(formula = Direction ~ Lag1 + Lag2, data = Smarket, subset = train)

extract_fit_engine(the_workflow_fit_lda_fit)$call
#> lda(formula = ..y ~ ., data = data)

parsnip::repair_call()函数会将data替换为我们传入的数据。此外,我们会将数据的响应重命名为..y以匹配调用。

the_workflow_fit_lda_fit %>%
  extract_fit_parsnip() %>%
  parsnip::repair_call(rename(Smarket_train, ..y = Direction)) %>%
  extract_fit_engine() %>%
  plot()

reprex package (v2.0.1)

于 2021-11-12 创建