在 mgcViz 图中结合多个簇的 gam 平滑

Combine gam smooths of multiple clusters in an mgcViz plot

我对来自结合了不等长时间序列的基于模型的聚类代码的几个聚类进行了 gam 平滑处理,我​​想将它们与数据一起显示。

mgcViz 包为单个集群提供了出色的可视化效果,但我不知道如何将它们组合起来。也许是因为它旨在可视化多个效果而不是几个集群。尽管如此,它的能力非常接近我的需要,所以这里有一个可重现的例子(改编自https://mfasiolo.github.io/mgcViz/articles/mgcviz.html):

library(mgcViz)
n = 1e3
z = rnorm(n)
dat = data.frame(x = rep(z, times = 2),
                 y = rep(c(1,2), each = n) + c(sin(z), 0.5*z^2) + rnorm(2*n)/4,
                 g = factor(rep(1:2, each = n)))

b <- lapply(1:2, function(i, dat) gam(y ~ s(x), data = dat[dat$g == i, ]),
            dat = dat)

plot(getViz(b[[1]])) + l_points() + l_fitLine() + l_ciLine()   # First
plot(getViz(b[[2]])) + l_points() + l_fitLine() + l_ciLine()   # Second
plot(getViz(b))   # Third
ggplot(dat, aes(x, y, color = g)) + geom_point(pch = ".") + theme_bw() # Fourth

我想将前两个图合并为一个,就像在第三个图中部分完成的那样。将第三个图放入第四个显示的数据中就可以了。这也需要在第三个绘图拟合中进行不同的截距偏移。将 l_points() 添加到第三个图使其为空。

一个隐藏的约束是 gam 平滑是单独的列表组件(如上所示),因为它们实际上来自使用 mgcv 的不等长度和间距的时间序列片段的自定义聚类代码 bam 用于非常大的数据。因此绘图应该最好从 b 中获取所有信息,每个集群的列表 gam 结果。

不是mgcViz,但您可以简单地自己创建所需的输出并使用ggplot2:

library(mgcv)
library(ggplot2)
theme_set(theme_bw())

n = 1e3
z = rnorm(n)
dat = data.frame(
  z = rep(z, times = 2),
  y = rep(c(1,2), each = n) + c(sin(z), 0.5*z^2) + rnorm(2*n)/4,
  g = factor(rep(1:2, each = n)))

b <- gam(y ~ g + s(z, by = g), data = dat)

ndf <- expand.grid(z = seq(min(dat$z), max(dat$z), length.out=100), g = unique(dat$g))
ndf$pred <- predict(b, newdata = ndf, type = "response")

ggplot(ndf, aes(x = z, y = pred, col = g)) +
  geom_line() +
  geom_point(data = dat, aes(y = y))

reprex package (v0.3.0)

于 2021-01-28 创建

编辑

如果你想为每个组拟合单独的模型,可以这样做:

library(purrr)
ndf <- map_dfr(
  .x = unique(dat$g),
  .f = ~{
    mod_i <- gam(y ~ s(x), data = dat[dat$g == .x, ])
    ndf_i <- expand.grid(x = seq(min(dat$x), max(dat$x), length.out = 100))
    ndf_i$g <- .x
    ndf_i$pred <- predict(mod_i, newdata = ndf_i, type = "response")
    ndf_i
  })

ggplot(ndf, aes(x = x, y = pred, col = g)) +
  geom_line() +
  geom_point(data = dat, aes(y = y))