将 R 代码转换为 C++ 以实现 Rcpp

Converting a R code into C++ for Rcpp implementation

我有一个用 R 编写的简单“for 循环”,我必须将其转换为 C++。以下是 R 代码的可重现示例:

# Parameters required 
a <- 1.8
b <- 1
time.dt <- 0.1
yp <- 40 
insp.int <- 7
ph <- 2000

dt <- seq(0,ph,time.dt) # Time sequence
MD.set <- c(seq(insp.int, ph, insp.int), ph) # Decision points to check and set next inspection date

# Initialization
cum.y <- rep(0,length = length(dt)) 
init.y <- 0
flag <- FALSE

# At each iteration, the following loop generates a gamma distributed random number and cum.y keeps taking cumulative sum
# The objective is to return a vector cum.y with a conditional cumulative sum of previous iteration 
# When dt[i] matches any values in MD.Set AND corresponding cum.y[i] is also >= yp it changes the flag to true (the last if)
# At the start of the loop it checks if dt[i] matches any values in MD.Set AND flag is also true. If yes, then cum.y is reset to 0. 

for (i in 2:length(dt)){
  if (dt[i] %in% MD.set && flag == TRUE){
    cum.y[i] <- 0
    init.y <- 0
    flag <- FALSE
    next
  } else {
    cum.y[i] <- init.y + rgamma(n = 1, shape = a*time.dt, scale = b)
    init.y <- cum.y[i]
    if (dt[i] %in% MD.set && cum.y[i] >= yp){
      flag <- TRUE
    }
  }
}

res <- cbind(dt, cum.y)

我之前没有使用 C++ 的经验,因此在尝试这样做时 运行 遇到了很多问题。我需要进行此转换只是为了能够在 R 的 Rcpp 包中使用它。因为代码在 R 中运行缓慢,特别是当 time.dt 变小时,我猜 C++ 会更快地完成这项工作。你能帮忙吗?

更新 2: 这是我在评论和回答的帮助下提出的转换建议。但是,我不确定 C++next 的等价物是什么。如果我使用 continue 它会继续执行其余代码(并执行 else 之后的代码。如果我使用 break 然后它会在条件为真后退出循环。

NumericVector cumy(double a, double b, double timedt, NumericVector dt, NumericVector MDSet, double yp){
  bool flag = false;
  int n = dt.size();
  double total = 0;
  NumericVector out(n);
  unordered_set<int> sampleSet(MDSet.begin(), MDSet.end());
  
  for (int i = 0; i < n; ++i){
    if (sampleSet.find(dt[i]) != sampleSet.end() && flag == true){
      out[i] = 0;
      total = 0;
      flag = false;
      continue;
    } else {
      out[i] = total + rgamma(1, a*timedt, b)[0];
      total = out[i];
      if (sampleSet.find(dt[i]) != sampleSet.end() && out[i] >= yp){
        flag = true;
      }
    }
  }
  return out;
}

您收到的错误仅仅是因为没有从 NumericVectorstd::unordered_set<int> 的自动转换。您可以通过以下方式解决此问题:

std::unordered_set<int> sampleSet( MDSet.begin(), MDSet.end() )

这将使用 MDSet 的开始和结束迭代器调用 unordered_set 的构造函数,这将使用所有值填充集合。

您的代码中还有另一个问题:

if (sampleSet.find(dt[i]) == sampleSet.begin())

仅当在 sampleSet 的第一个元素处找到 dt[i] 时才为真。根据您的 r 代码,我假设您只是检查值 dt[i] 是否在 sampleSet 内,在这种情况下,您需要:

if (sampleSet.find(dt[i]) != sampleSet.end())

在C++中,STL find方法一般return一个迭代器,当值为找到时,它return是结束迭代器,所以如果 find 的 return 值不是 end,则在集合中找到该值。