如何使用存储在 data.table 或 tibble 中的线性模型添加预测列?

How to add a prediction column with linear models stored in data.table or tibble?

我将多个线性模型存储在一个 table 中。现在我想使用 reach 行中的模型来预测一个 y 值,给定相应行中的一个 x 值。

困难是由 data.table 和 tidyverse 在 table 中提取模型的方式造成的。 predict.lm 函数需要 class "lm" 对象,在 class "list" 对象中。

library(data.table)

model1 <- lm( y~x, data= data.table( x=c(1,2,3,4) , y=c(1,2,1,2) ))
model2 <- lm( y~x, data= data.table( x=c(1,2,3,4) , y=c(1,2,3,3) ))

model_dt <- data.table( id = c(1,2), model = list(model1, model2), x = c(3,3))

现在 model_dt 包含线性模型和所需的 x 值。

逐行预测效果很好:

predict.lm( model_dt[1]$model[[1]], model_dt[1])  # yields 1.6
predict.lm( model_dt[2]$model[[1]], model_dt[2])  # yields 2.6

但是直接添加列会报错:

model_dt[, pred_y := predict.lm( model , x )]         # ERROR
model_dt[, pred_y := predict.lm( model , x ), by=id]  # ERROR

============================================= ===================

tidyverse 中的相同设置:

library(tidyverse)

model1 <- lm( y~x, data= tibble( x=c(1,2,3,4) , y=c(1,2,1,2) ))
model2 <- lm( y~x, data= tibble( x=c(1,2,3,4) , y=c(1,2,3,3) ))

model_dt <- tibble( id = c(1,2), model = list(model1, model2), x = c(3,3))

predict.lm( model_dt[1,]$model[[1]], model_dt[1,])  # yields 1.6
predict.lm( model_dt[2,]$model[[1]], model_dt[2,])  # yields 2.6

并且使用 mutate 添加变量失败:

model_dt %>% mutate( pred_y = predict.lm( model, x ) )  # ERROR

似乎一个原因是,table 中 "model" 列中的模型无法提取为 class "lm" 对象,但使用模型data.table 或 mutate 函数中的 [[1]] 始终引用第 1 行中的模型。

class( model_dt[1,]$model )      # results in class "list"
class( model_dt[1,]$model[[1]] ) # results in class "lm"

结果应该是 table 如下:

   id model x pred_y
1:  1  <lm> 3    1.6
2:  2  <lm> 3    2.6

我相信有一个简单的解决方案,并且会很高兴收到您的意见。 map() 或 lapply() 的可能解决方案也有同样的问题。非常感谢。

============================================= ========================

编辑:除了问题

之外,此问题还在 data.table 中寻求解决方案

对于tidyverse,我们使用map2循环遍历'model',对应的'x'值,将predict中的新数据作为data.frametibble

library(tidyverse)
model_dt %>% 
   mutate(pred_y = map2_dbl(model, x, ~ predict.lm(.x, tibble(x = .y))))
# A tibble: 2 x 4
#     id model      x pred_y
#   <dbl> <list> <dbl>  <dbl>
#1     1 <lm>       3   1.6 
#2     2 <lm>       3   2.60

或使用 data.table(对象)和 Map

model_dt[,  pred_y := unlist(Map(function(mod, y) 
          predict.lm(mod, data.frame(x = y)), model, x)), id][]
#   id model x pred_y
#1:  1  <lm> 3    1.6
#2:  2  <lm> 3    2.6