如何解决 R 中 for 循环的问题

How do fix issue with for loop in R

我正在为想要基于 AADTMAJ、L 和 Base_Past 预测值的人们编写一个程序包。该函数提供两个选项 1) 允许用户输入自己的回归系数,或 2) 为用户提供预定义的系数。但是,我无法正确使用 return() 。

输入数据

data=data.frame(Base_Past=c("HSM-RUR2U-KABCO",
                            "HSM-RUR2U-KABCO",
                            "HSM-RUR4-KABC",
                            "HSM-RUR4-KABCO"),
                AADTMAJ=c(100,100,100,100),
                L=c(1,1,1,1)
)

输入自定义回归系数

custom.spf=data.frame(Base_Past=c("HSM-RUR2U-KABCO","HSM-RUR2U-KABC"), a=c(-0.312,-0.19))

定义辅助函数

helper_function = function (data, Base_Past=FALSE, override=custom.spf){
  if (is.data.frame(override)){
    for (j in 1:nrow(override)){
      for (i in 1:nrow(data)){
        if(data[i, ]$Base_Past==override[j, ]$Base_Past){
          output=as.numeric(data[i, ]$AADTMAJ*data[i, ]$L*365*10^(-6)*exp(override[j, ]$a))
          return(output)} else{
            if(data[i, ]$Base_Past=="HSM-RUR4-KABCO") {a=-0.101}
            if(data[i, ]$Base_Past=="HSM-RUR4-KABC") {a=-0.143}
            output=as.numeric(data[i, ]$AADTMAJ*data[i, ]$L*365*10^(-6)*exp(a))
            return(output)
          } 
      }
    }
  }
  
  else if (!is.data.frame(override)){
    if(Base_Past=="HSM-RUR4-KABCO") {a=-0.101}
    if(Base_Past=="HSM-RUR4-KABC") {a=-0.143}
    output=as.numeric(data[i, ]$AADTMAJ*data[i, ]$L*365*10^(-6)*exp(a))
    return(output)
  }
}

运行

(data %>% dplyr::rowwise() %>% dplyr::mutate(predicted_value = helper_function(data = data, override=custom.spf)))[,4]



输出

# A tibble: 4 x 1
# Rowwise: 
  predicted_value
            <dbl>
1          0.0267
2          0.0267
3          0.0267
4          0.0267

备选

data %>% dplyr::mutate(predicted_value=dplyr::case_when(Base_Past =="HSM-RUR4-KABCO" ~AADTMAJ*L*365*10^(-6)*exp(-0.101),
                                                        Base_Past=="HSM-RUR4-KABC" ~AADTMAJ*L*365*10^(-6)*exp(-0.143),
                                                        Base_Past=="HSM-RUR2U-KABCO" ~AADTMAJ*L*365*10^(-6)*exp(-0.312),
                                                        Base_Past=="HSM-RUR2U-KABC" ~AADTMAJ*L*365*10^(-6)*exp(-0.190),
                                                        TRUE ~ NA_real_))

期望的输出

        Base_Past AADTMAJ L predicted_value
1 HSM-RUR2U-KABCO     100 1      0.02671733
2 HSM-RUR2U-KABCO     100 1      0.02671733
3   HSM-RUR4-KABC     100 1      0.03163652
4  HSM-RUR4-KABCO     100 1      0.03299356

该功能和您对它的使用有几个问题。自从我的第一批评论以来,在问题列表中值得注意:

  • 您在 rowwise 管道中调用它,然后传递 data=data,这意味着它 忽略 传入的数据管道,而不是看着整个事情。您可以改为使用 data=cur_data()(因为它在 mutate 内部,这是可行的,因为 cur_data()dplyr 定义,用于类似这种情况)。

  • 您的 helper_function 是 ill-defined,假设 custom.spf 已定义且可用。让一个函数依赖于未显式传递给它的外部变量的存在会使它变得脆弱并且很难进行故障排除。例如,如果 custom.spf 未在调用环境中定义,则此函数将 失败 object 'custom.spf' not found。相反,我认为你可以使用:

    helper_function <- function(..., override=NA) {
      if (isTRUE(is.na(override)) && exists("custom.spf")) {
        message("found 'custom.spf', using it as 'override'")
        override <- custom.spf
      }
      ...
    }
    

    我对此仍然不是很兴奋,但至少它不会 失败 太快,而且它所做的事情很冗长。

  • 如果以编程方式使用,1:nrow(.)可能有一点风险。也就是说,如果出于某种原因,其中一个输入有 0 行(也许 custom.spf 没有什么可以覆盖),那么 1:nrow(.) 在逻辑上应该什么也不做,而是对不存在的行进行两次迭代。即如果nrow(.)为0,则注意1:0returnsc(1, 0),这显然不是“什么都不做”。相反,使用 seq_len(nrow(.)) 作为 seq_len(0) returns integer(0),这正是我们想要的。

  • 没有理由在这里使用rowwise(),应该尽可能避免使用它。 (它做的非常好,当确实有必要时,它工作得很好。但是一次迭代一行的性能损失可能很大,尤其是对于较大的数据。)

通过学习 merge/join 方法可以简化您尝试做的一些事情。 merge/join 的两个非常好的参考资料是:How to join (merge) data frames (inner, outer, left, right), What's the difference between INNER JOIN, LEFT JOIN, RIGHT JOIN and FULL JOIN?.

此外,您的大部分工作似乎都是为方程式 a 分配一个合理的值。您的内部代码(寻找 "-KABCO""-KABC")看起来确实应该是另一个默认值框架。

这里有一个建议 helper_function,它稍微改变了一些东西。它采用强制参数 Base_PastAADTMAJL,然后零个或多个帧到 merge/join,以便在 a 中找到合适的值等式。

helper_function <- function(Base_Past, AADTMAJ, L, ...) {
  stopifnot(
    length(Base_Past) == length(AADTMAJ),
    length(Base_Past) == length(L)
  )
  defaults <- data.frame(Base_Past = c("HSM-RUR4-KABCO", "HSM-RUR4-KABC"), a = c(-0.101, -0.143))
  frames <- c(list(defaults), list(...))
  a <- rep(NA, length(Base_Past))
  tmpdat <- data.frame(row = seq_along(Base_Past), Base_Past = Base_Past, a = a)
  for (frame in frames) {
    tmpdat <- merge(tmpdat, frame, by = "Base_Past", suffixes = c("", ".y"),
                    all.x = TRUE, sort = FALSE)
    tmpdat$a <- ifelse(is.na(tmpdat$a), tmpdat$a.y, tmpdat$a)
    tmpdat$a.y <- NULL
  }
  tmpdat <- tmpdat[order(tmpdat$row),]
  AADTMAJ * L * 365 * 10^(-6) * exp(tmpdat$a)
}

前提是您在函数中查找 a 的“默认”值与在 override 变量中查找它们实际上是一样的。我本可以为您提供单个查找字典的 override= 参数,但有时使用 “一个或多个” 类型的参数很有用:也许您有超过一帧具有 a 的其他值,您可能希望同时使用它们。这将按照您对单个的期望工作,但如果您有多个,也许 custom.spfcustom.spf,这将工作(通过在调用时将所有它们添加到 L 参数之后)。

出于一些原因,我选择保留函数内部简单的基 R,没有什么特别重要的。可以 dplyr 化的部分在 for (frame in frames) 循环内。

data %>%
  mutate(a = helper_function(Base_Past, AADTMAJ, L, custom.spf))
#         Base_Past AADTMAJ L          a
# 1 HSM-RUR2U-KABCO     100 1 0.02671733
# 2 HSM-RUR2U-KABCO     100 1 0.02671733
# 3   HSM-RUR4-KABC     100 1 0.03163652
# 4  HSM-RUR4-KABCO     100 1 0.03299356

如果您愿意,该函数应该在分组(group_byrowwise)内干净地运行,但肯定没有必要按照您最初的要求进行操作。