Speed up conditional R loop with c/c++ or vectorization

I need to create a matrix containing covariate values conditional on other variables in the dataset. Here is a self-contained example of my current solution:

library(dplyr)
library(survival)
library(microbenchmark)
data(heart, package = "survival")
data <- heart

# Number of unique subjects
n.sub <- data %>% group_by(id) %>% n_groups()
# Unique failure times
fail.time <- data %>% filter(event == 1) %>% distinct(stop) %>% arrange(stop) %>% .$stop
# Number of unique failure times
n.fail.time <- length(fail.time)

# Pre-fill matrix. Will be filled with covariate values.
mat <- matrix(NA_real_, nrow = n.sub, ncol = n.fail.time)
# Run loop
for(i in 1:n.sub) { # Number of subjects
  data.subject <- data[data$id == i, ] # subsetting here provides nice speed-up
  for(j in 1:n.fail.time) { # Number of failure times.
    value <- subset(data.subject, (start < fail.time[j]) & (stop >= fail.time[j]), select = transplant, drop = TRUE)
    if(length(value) == 0) { # An early event or censor will return empty value. Assign to zero.
      mat[i, j] <- 0
    }
    else {
      mat[i, j] <- value # True value
    }
  }
}

This is far too slow for datasets with thousands of observations. I don't know how best to vectorize this in R code, and I don't know enough C/C++ to take advantage of Rcpp. How can I speed up this example using one of these (or other) options?

It looks like the src/aalen.c file in the timereg package may have a C solution similar to my problem. See the code around the line with if ((start[c]<time) && (stop[c]>=time)). Though that may just be my ignorance of C programming showing.
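To make the target concrete, here is a rough Rcpp sketch of the kind of loop I imagine (I can't vouch for it given my limited C++; it assumes ids run 1..n.sub and that a subject's intervals never overlap, so at most one row matches per failure time):

library(Rcpp)

cppFunction('
NumericMatrix fillMat(IntegerVector id, NumericVector start,
                      NumericVector stop, NumericVector covar,
                      NumericVector failTime, int nSub) {
  int nRow = id.size(), nFail = failTime.size();
  NumericMatrix mat(nSub, nFail);           // zero-initialised by Rcpp
  for (int r = 0; r < nRow; r++) {          // one pass over the data rows
    for (int j = 0; j < nFail; j++) {
      // same condition as in aalen.c: the interval covers the failure time
      if (start[r] < failTime[j] && stop[r] >= failTime[j])
        mat(id[r] - 1, j) = covar[r];       // ids assumed to be 1..nSub
    }
  }
  return mat;
}')

mat.cpp <- fillMat(as.integer(data$id), data$start, data$stop,
                   as.numeric(data$transplant), fail.time, n.sub)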

I faced a similar choice about moving to C++ to speed up a different problem, but in the end I turned to R packages that already implement things efficiently in C++ and used those. Here, the package you want is called data.table.

This may be hard to follow if you are new to R, but the data.table package is well documented through its vignettes here. To get a feel for what is happening below, you may gain insight by stepping through the code on the simplified dataset I tested (see the bottom of this answer) and watching the objects as their values change. The key to the speed gain is using data.table's fast assignment methods and performing only vectorized operations.
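As a toy illustration of that assignment-by-reference idea (not part of the solution itself):

require(data.table)
dt <- data.table(x = 1:5)
dt[, y := x * 2L]    # := adds the column in place, without copying dt
dt[x > 3L, y := 0L]  # conditional update, still by reference and vectorized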

My solution follows. Note that I am not sure whether you need the 0, 1, 2 values, but I would be happy to change the code to produce 0, 1 if that is what you want (see the note after the code).

require(data.table)

dataDT <- data.table(data[, c("id", "start", "stop", "transplant")])
# add a serial number for each id
dataDT[, idObs := 1:length(start), by = id ]
# needed because transplant is a factor in the heart dataset
dataDT[, transplant := as.integer(transplant)]

# create a "long" format data.table of subjects, observation number, and start/stop times
matDT <- data.table(subject = rep(1:n.sub, each = n.fail.time * max(dataDT$idObs)),
                    idObs = rep(1:max(dataDT$idObs), max(dataDT$idObs), n.sub * max(dataDT$idObs)),
                    fail.time = rep(fail.time, each = max(dataDT$idObs)))
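# (note: idObs and fail.time are shorter than subject; data.table recycles
#  them to the common length, which gives the intended pattern here)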

# merge in start and stop times
setkey(matDT, subject, idObs)
setkey(dataDT, id, idObs)
matDT <- dataDT[matDT]

# eliminate missings (for which no 2nd observation took place)
matDT <- matDT[!is.na(transplant)]

# this replicates the "value" assignment in the loop
matDT[, value := transplant * ((start < fail.time) & (stop >= fail.time))]

# sum on the ids by fail time
matDT2 <- matDT[, list(matVal = sum(value)), by = list(id, fail.time)]

# convert to a matrix
mat2 <- matrix(matDT2$matVal, ncol = ncol(mat), byrow = TRUE, dimnames = list(1:n.sub, fail.time))
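As mentioned above, if you want 0/1 values rather than the 0/1/2 produced here, I believe swapping in the line below would do it (this assumes transplant's factor levels are the strings "0" and "1", as they are in heart, so the labels rather than the level codes are kept):

# use in place of the as.integer(transplant) line above
dataDT[, transplant := as.integer(as.character(transplant))]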

According to microbenchmark(), this is many times faster than your code (the first method below is the code from the question):

Unit: milliseconds
         min         lq       mean     median         uq       max neval
  310.503535 339.364159 396.287178 354.292829 406.937216 762.28838   100
    7.113083   7.420517   9.436973   7.788479   9.426443  32.50355   100

To show the output, I tested on the first six rows of your data object. This makes a good example because the third and fourth patients (id = 3, 4) each have two observations, one from before the transplant and one from after.

data <- heart[1:6, ]

Then I added row and column labels to your mat object:

colnames(mat) <- fail.time
rownames(mat) <- 1:n.sub
mat
##   6 16 39 50
## 1 1  1  1  1
## 2 1  0  0  0
## 3 2  2  0  0
## 4 1  1  2  0

Here you can see that the new mat2 is identical:

mat2
##   6 16 39 50
## 1 1  1  1  1
## 2 1  0  0  0
## 3 2  2  0  0
## 4 1  1  2  0
all.equal(mat, mat2)
## [1] TRUE

Here is a dplyr version of @KenBenoit's solution (see the dplyr.matrix function). Below is code that tests all three approaches.

library(dplyr)
library(data.table)
library(survival)
library(microbenchmark)
data(heart, package = "survival")
data <- heart

old.matrix <- function(data) {
  # Number of unique subjects
  n.subjects <- data %>% group_by(id) %>% n_groups()
  # Unique failure times
  fail.time <- data %>% filter(event == 1) %>% distinct(stop) %>% arrange(stop) %>% .$stop
  # Number of unique failure times
  n.fail.time <- length(fail.time)

  # Pre-fill matrix. Will be filled with covariate values.
  mat <- matrix(NA_real_, nrow = n.subjects, ncol = n.fail.time)
  # Run loop
  for(i in 1:n.subjects) { # Number of subjects
    data.subject <- data[data$id == i, ] # subsetting here provides nice speed-up
    for(j in 1:n.fail.time) { # Number of failure times.
      value <- subset(data.subject, (start < fail.time[j]) & (stop >= fail.time[j]), select = transplant, drop = TRUE)
      if(length(value) == 0) { # An early event or censor will return empty value. Assign to zero.
        mat[i, j] <- 0
      }
      else {
        mat[i, j] <- value # True value
      }
    }
  }
  mat
}

dplyr.matrix <- function(data) {
  # Number of unique subjects
  n.subjects <- data %>% group_by(id) %>% n_groups()
  # Unique failure times
  fail.time <- data %>% filter(event == 1) %>% distinct(stop) %>% arrange(stop) %>% .$stop
  # Number of unique failure times
  n.fail.time <- length(fail.time)

  # add a serial number for each id
  data <- data %>% group_by(id) %>% mutate(id.serial = 1:length(start))
  # needed because transplant is a factor in the heart dataset
  data$transplant <- as.integer(data$transplant)
  # create a "long" format data.frame of subjects, observation number, and start/stop times
  data.long <- data.frame(
    id = rep(1:n.subjects, each = n.fail.time * max(data$id.serial)),
    id.serial = rep(1:max(data$id.serial), max(data$id.serial), n.subjects * max(data$id.serial)),
    fail.time = rep(fail.time, each = max(data$id.serial))
  )
  # merge in start and stop times
  data.merge <- left_join(data.long, data[, c("start", "stop", "transplant", "id", "id.serial")], by = c("id", "id.serial"))
  # eliminate missings (for which no 2nd observation took place)
  data.merge <- na.omit(data.merge)
  # this replicates the "value" assignment in the loop
  data.merge <- data.merge %>% mutate(value = transplant * ((start < fail.time) & (stop >= fail.time)))
  # sum on the ids by fail time
  data.merge <- data.merge %>% group_by(id, fail.time) %>% summarise(value = sum(value))
  # convert to a matrix
  data.matrix <- matrix(data.merge$value, ncol = n.fail.time, byrow = TRUE, dimnames = list(1:n.subjects, fail.time))
  data.matrix
}

data.table.matrix <- function(data) {
  # Number of unique subjects
  n.subjects <- data %>% group_by(id) %>% n_groups()
  # Unique failure times
  fail.time <- data %>% filter(event == 1) %>% distinct(stop) %>% arrange(stop) %>% .$stop
  # Number of unique failure times
  n.fail.time <- length(fail.time)

  dataDT <- data.table(data[, c("id", "start", "stop", "transplant")])
  # add a serial number for each id
  dataDT[, idObs := 1:length(start), by = id ]
  # needed because transplant is a factor in the heart dataset
  dataDT[, transplant := as.integer(transplant)]
  # create a "long" format data.table of subjects, observation number, and start/stop times
  matDT <- data.table(subject = rep(1:n.subjects, each = n.fail.time * max(dataDT$idObs)),
                      idObs = rep(1:max(dataDT$idObs), max(dataDT$idObs), n.subjects * max(dataDT$idObs)),
                      fail.time = rep(fail.time, each = max(dataDT$idObs)))
  # merge in start and stop times
  setkey(matDT, subject, idObs)
  setkey(dataDT, id, idObs)
  matDT <- dataDT[matDT]
  # eliminate missings (for which no 2nd observation took place)
  matDT <- matDT[!is.na(transplant)]
  # this replicates the "value" assignment in the loop
  matDT[, value := transplant * ((start < fail.time) & (stop >= fail.time))]
  # sum on the ids by fail time
  matDT2 <- matDT[, list(matVal = sum(value)), by = list(id, fail.time)]
  # convert to a matrix
  mat2 <- matrix(matDT2$matVal, ncol = n.fail.time, byrow = TRUE, dimnames = list(1:n.subjects, fail.time))
  mat2
}

all(dplyr.matrix(data) == old.matrix(data))
all(dplyr.matrix(data) == data.table.matrix(data))

microbenchmark(
  old.matrix(data),
  dplyr.matrix(data),
  data.table.matrix(data),
  times = 50
)

Output of the microbenchmark:

Unit: milliseconds
                    expr        min         lq      mean    median        uq       max neval cld
        old.matrix(data) 325.949687 328.102482 333.20923 329.39368 331.28305 373.44774    50   c
      dplyr.matrix(data)  17.586146  18.317833  20.04662  18.95724  19.62431  60.15858    50  b 
 data.table.matrix(data)   9.464045   9.892281  10.72819  10.29394  11.44812  12.67738    50 a  

The results above correspond to a dataset with roughly 100 observations. When I tested on a dataset with roughly 1000 observations, data.table began to pull further ahead:

Unit: milliseconds
                    expr        min         lq       mean     median        uq        max neval cld
        old.matrix(data) 13095.7836 13114.1858 13162.5019 13134.0735 13150.217 13318.2496     5   c
      dplyr.matrix(data)  1067.1942  1075.5291  1149.0789  1166.8951  1197.998  1237.7787     5  b 
 data.table.matrix(data)   104.5133   155.2074   159.6794   159.6364   166.764   212.2758     5 a  

data.table wins, for now.
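If you want to rerun the larger comparison yourself, one way to build a dataset of roughly that size (not necessarily the one I used) is to stack copies of heart while keeping the ids consecutive:

# stack 6 copies of heart (~1000 rows), shifting ids so they stay 1..n
data.big <- do.call(rbind, lapply(1:6, function(k) {
  d <- heart
  d$id <- d$id + (k - 1) * max(heart$id)
  d
}))

microbenchmark(
  old.matrix(data.big),
  dplyr.matrix(data.big),
  data.table.matrix(data.big),
  times = 5
)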