R将回归模型转换为函数

R convert regression model fit to a function

我想快速提取回归模型对函数的拟合。

所以我想从:

# generate some random data
set.seed(123)
x <- rnorm(n = 100, mean = 10, sd = 4)
z <- rnorm(n = 100, mean = -8, sd = 3)
y <- 9 * x - 10 * x ^ 2 + 5 * z + 10 + rnorm(n = 100, 0, 30)


df <- data.frame(x,y)
plot(df$x,df$y)

model1 <- lm(formula = y ~ x + I(x^2) + z, data = df)
summary(model1)

到描述我的拟合值的 model_function(x)

当然我可以用这样的方式手工完成:

model_function <- function(x, z, model) {
  fit <- coefficients(model)["(Intercept)"] + coefficients(model)["x"]*x + coefficients(model)["I(x^2)"]*x^2 + coefficients(model)["z"]*z
  return(fit)
}

fit <- model_function(df$x,df$z, model1)

我可以将其与实际拟合值进行比较,并且(有一些舍入误差)效果很好。

all(round(as.numeric(model1$fitted.values),5) == round(fit,5))

但这当然不是一个通用的解决方案(例如更多变量等)。

明确一点: 是否有一种简单的方法可以将拟合值关系提取为刚刚估计的系数的函数?

注意:我当然知道 predict 以及从新数据生成拟合值的能力 - 但我确实在寻找该基础函数。也许这可以通过 predict?

感谢任何帮助!

其中任何一个给出了拟合值:

fitted(model1)

predict(model1)

model.matrix(model1) %*% coef(model1)

y - resid(model1)

X <- model.matrix(model1); X %*% qr.solve(X, y)

X <- cbind(1, x, x^2, z); X %*% qr.solve(X, y)

其中任何一个给出了任何特定 x 和 z 的预测值:

cbind(1, x, x^2, z) %*% coef(model1)

predict(model1, list(x = x, z = z))

如果你想要一个实际的功能,你可以这样做:

get_func <- function(mod) {
  vars <- as.list(attr(mod$terms, "variables"))[-(1:2)]
  funcs <- lapply(vars, function(x) list(quote(`*`), 1, x))
  terms <- mapply(function(x, y) {x[[2]] <- y; as.call(x)}, funcs, mod$coefficients[-1],
         SIMPLIFY = FALSE)
  terms <- c(as.numeric(mod$coefficients[1]), terms)
  body <- Reduce(function(a, b) as.call(list(quote(`+`), a, b)), terms)
  vars <- setNames(lapply(seq_along(vars), function(x) NULL), sapply(vars, as.character))
  f <- as.function(c(do.call(alist, vars), body))
  formals(f) <- formals(f)[!grepl("\(", names(formals(f)))]
  f
}

允许:

my_func <- get_func(model1)

my_func
#> function (x = NULL, z = NULL) 
#> 48.6991866925322 + 3.31343108778127 * x + -9.77589420188036 * I(x^2) + 5.38229596972984 * z
<environment: 0x00000285a1982b48>

my_func(x = 1:10, z = 3)
#> [1]   58.38361   32.36936  -13.19668  -78.31451 -162.98413 -267.20553 
#> [7] -390.97872 -534.30371 -697.18048 -879.60903

 plot(1:10, my_func(x = 1:10, z = 3), type = "b")

目前,这不适用于交互项等,但应该适用于大多数简单的线性模型