如何使用 R `tidymodels` 复制 plot.lda()
How can I replicate plot.lda() with of R `tidymodels`
我想使用 ggplot2
和 tidymodels
复制 plot.lda 打印方法。有没有优雅的方式获取剧情?
我想我可以伪造 augment()
函数,它没有 lda 方法,方法是使用 predict()
并将其绑定到原始数据。
这是一个使用基本 R 和 tidymodels
代码的示例:
library(ISLR2)
library(MASS)
# First base R
train <- Smarket$Year < 2005
lda.fit <-
lda(
Direction ~ Lag1 + Lag2,
data = Smarket,
subset = train
)
plot(lda.fit)
# Next tidymodels
library(tidyverse)
library(tidymodels)
library(discrim)
lda_spec <- discrim_linear() %>%
set_mode("classification") %>%
set_engine("MASS")
the_rec <- recipe(
Direction ~ Lag1 + Lag2,
data = Smarket
)
the_workflow<- workflow() %>%
add_recipe(the_rec) %>%
add_model(lda_spec)
Smarket_train <- Smarket %>%
filter(Year != 2005)
the_workflow_fit_lda_fit <-
fit(the_workflow, data = Smarket_train) %>%
extract_fit_parsnip()
# now my attempt to do the plot
predictions <- predict(the_workflow_fit_lda_fit,
new_data = Smarket_train,
type = "raw"
)[[3]] %>%
as.vector()
bind_cols(Smarket_train, .fitted = predictions) %>%
ggplot(aes(x=.fitted)) +
geom_histogram(aes(y = stat(density)),binwidth = .5) +
scale_x_continuous(breaks = seq(-4, 4, by = 2))+
facet_grid(vars(Direction)) +
xlab("") +
ylab("Density")
一定有更好的方法来做到这一点....有什么想法?
您可以结合使用 extract_fit_*()
和 parsnip:::repair_call()
来完成此操作。 plot.lda()
方法使用了 LDA 拟合中的 $call
对象,我们需要对其进行调整,因为使用 tidymodels 的调用对象将不同于直接使用 lda()
。
library(ISLR2)
library(MASS)
# First base R
train <- Smarket$Year < 2005
lda.fit <-
lda(
Direction ~ Lag1 + Lag2,
data = Smarket,
subset = train
)
# Next tidymodels
library(tidyverse)
library(tidymodels)
library(discrim)
lda_spec <- discrim_linear() %>%
set_mode("classification") %>%
set_engine("MASS")
the_rec <- recipe(
Direction ~ Lag1 + Lag2,
data = Smarket
)
the_workflow <- workflow() %>%
add_recipe(the_rec) %>%
add_model(lda_spec)
Smarket_train <- Smarket %>%
filter(Year != 2005)
the_workflow_fit_lda_fit <-
fit(the_workflow, data = Smarket_train)
拟合两个模型后,我们可以检查 $call
个对象,我们发现它们不同。
lda.fit$call
#> lda(formula = Direction ~ Lag1 + Lag2, data = Smarket, subset = train)
extract_fit_engine(the_workflow_fit_lda_fit)$call
#> lda(formula = ..y ~ ., data = data)
parsnip::repair_call()
函数会将data
替换为我们传入的数据。此外,我们会将数据的响应重命名为..y
以匹配调用。
the_workflow_fit_lda_fit %>%
extract_fit_parsnip() %>%
parsnip::repair_call(rename(Smarket_train, ..y = Direction)) %>%
extract_fit_engine() %>%
plot()
由 reprex package (v2.0.1)
于 2021-11-12 创建
我想使用 ggplot2
和 tidymodels
复制 plot.lda 打印方法。有没有优雅的方式获取剧情?
我想我可以伪造 augment()
函数,它没有 lda 方法,方法是使用 predict()
并将其绑定到原始数据。
这是一个使用基本 R 和 tidymodels
代码的示例:
library(ISLR2)
library(MASS)
# First base R
train <- Smarket$Year < 2005
lda.fit <-
lda(
Direction ~ Lag1 + Lag2,
data = Smarket,
subset = train
)
plot(lda.fit)
# Next tidymodels
library(tidyverse)
library(tidymodels)
library(discrim)
lda_spec <- discrim_linear() %>%
set_mode("classification") %>%
set_engine("MASS")
the_rec <- recipe(
Direction ~ Lag1 + Lag2,
data = Smarket
)
the_workflow<- workflow() %>%
add_recipe(the_rec) %>%
add_model(lda_spec)
Smarket_train <- Smarket %>%
filter(Year != 2005)
the_workflow_fit_lda_fit <-
fit(the_workflow, data = Smarket_train) %>%
extract_fit_parsnip()
# now my attempt to do the plot
predictions <- predict(the_workflow_fit_lda_fit,
new_data = Smarket_train,
type = "raw"
)[[3]] %>%
as.vector()
bind_cols(Smarket_train, .fitted = predictions) %>%
ggplot(aes(x=.fitted)) +
geom_histogram(aes(y = stat(density)),binwidth = .5) +
scale_x_continuous(breaks = seq(-4, 4, by = 2))+
facet_grid(vars(Direction)) +
xlab("") +
ylab("Density")
一定有更好的方法来做到这一点....有什么想法?
您可以结合使用 extract_fit_*()
和 parsnip:::repair_call()
来完成此操作。 plot.lda()
方法使用了 LDA 拟合中的 $call
对象,我们需要对其进行调整,因为使用 tidymodels 的调用对象将不同于直接使用 lda()
。
library(ISLR2)
library(MASS)
# First base R
train <- Smarket$Year < 2005
lda.fit <-
lda(
Direction ~ Lag1 + Lag2,
data = Smarket,
subset = train
)
# Next tidymodels
library(tidyverse)
library(tidymodels)
library(discrim)
lda_spec <- discrim_linear() %>%
set_mode("classification") %>%
set_engine("MASS")
the_rec <- recipe(
Direction ~ Lag1 + Lag2,
data = Smarket
)
the_workflow <- workflow() %>%
add_recipe(the_rec) %>%
add_model(lda_spec)
Smarket_train <- Smarket %>%
filter(Year != 2005)
the_workflow_fit_lda_fit <-
fit(the_workflow, data = Smarket_train)
拟合两个模型后,我们可以检查 $call
个对象,我们发现它们不同。
lda.fit$call
#> lda(formula = Direction ~ Lag1 + Lag2, data = Smarket, subset = train)
extract_fit_engine(the_workflow_fit_lda_fit)$call
#> lda(formula = ..y ~ ., data = data)
parsnip::repair_call()
函数会将data
替换为我们传入的数据。此外,我们会将数据的响应重命名为..y
以匹配调用。
the_workflow_fit_lda_fit %>%
extract_fit_parsnip() %>%
parsnip::repair_call(rename(Smarket_train, ..y = Direction)) %>%
extract_fit_engine() %>%
plot()
由 reprex package (v2.0.1)
于 2021-11-12 创建