找到最长的滚动总和 under/equal 每个组的阈值

Find longest rolling sum under/equal a threshold for each group

我正在尝试找到最有效的方法来识别值低于特定阈值的向量的最长滚动总和。例如,如果我们有 1:10 和阈值 6,那么 3 就是我们最长的滚动总和。

我已经得到了我为此制作的以下代码和功能,但显然速度非常慢。想知道是否有更有效或已经实现的算法来识别总和低于阈值的最长运行。

library(dplyr)
library(zoo)

set.seed(1)

df <- data.frame(
  g = rep(letters[1:10], each = 100),
  x = runif(1000)
)

longest_rollsum <- function(x, threshold) {
  for (i in 1:length(x)) {
    rs <- rollsum(x, i, na.pad = TRUE, align = "right")
    if (!any(rs <= threshold, na.rm = TRUE)) {
      return(i - 1)
    }
  }
  return(i)
}

df %>%
  group_by(g) %>%
  summarize(longest = longest_rollsum(x, 2))
#> # A tibble: 10 × 2
#>    g     longest
#>    <chr>   <dbl>
#>  1 a           7
#>  2 b           6
#>  3 c           9
#>  4 d           6
#>  5 e           6
#>  6 f           8
#>  7 g           6
#>  8 h           7
#>  9 i          11
#> 10 j           9

提高性能的一种方法是使用 zoo 包中的 roll_sum() function from the RcppRoll package instead of rollsum()

library(tidyverse)
library(zoo)
#> 
#> Attaching package: 'zoo'
#> The following objects are masked from 'package:base':
#> 
#>     as.Date, as.Date.numeric

set.seed(1)

df <- data.frame(
  g = rep(letters[1:10], each = 100),
  x = runif(1000)
)

longest_rollsum <- function(x, threshold) {
  for (i in 1:length(x)) {
    rs <- zoo::rollsum(x, i, na.pad = TRUE, align = "right")
    if (!any(rs <= threshold, na.rm = TRUE)) {
      return(i - 1)
    }
  }
  return(i)
}

longest_rollsum_rcpp <- function(x, threshold) {
  for (i in 1:length(x)) {
    rs <- RcppRoll::roll_sum(x, i, align = "right", fill=NA)
    if (!any(rs <= threshold, na.rm = TRUE)) {
      return(i - 1)
    }
  }
  return(i)
}

rollsum <- function(df){
  df %>%
    group_by(g) %>%
    summarize(longest = longest_rollsum(x, 2))
}

rollsum_rcpp <- function(df){
  df %>%
    group_by(g) %>%
    summarize(longest = longest_rollsum_rcpp(x, 2))
}

library(microbenchmark)
res <- microbenchmark(rollsum(df), rollsum_rcpp(df))
autoplot(res)
#> Coordinate system already present. Adding new coordinate system, which will replace the existing one.

all_equal(rollsum(df), rollsum_rcpp(df))
#> [1] TRUE

reprex package (v2.0.1)

于 2021-12-17 创建

您可以编写一个更快的 R 函数:

longest_rollsum_R <- function(x, threshold) {
  R_rollsum <- function(k){
    for(i in seq(length(x) - k))
      if(sum(x[i:(i+k)]) < threshold) return(FALSE)
    TRUE
  }
  for (i in 1:length(x)) if(R_rollsum(i)) return(i)    
}

现在与上面的比较,

Unit: milliseconds
             expr       min        lq      mean    median        uq       max neval
      rollsum(df) 23.440017 25.407785 29.007619 27.906883 31.014434 45.516888   100
 rollsum_rcpp(df)  4.046400  4.499688  5.253406  4.734718  5.596438 14.079618   100
   rollsum_R1(df)  3.798058  4.194639  5.100568  4.710468  5.280267 14.749520   100

变化似乎不大。但是当我们改变阈值时它是相当大的:

Unit: milliseconds
                 expr        min         lq       mean    median         uq       max neval
      rollsum(df, 20) 111.336885 130.055676 142.683567 138.11347 147.231358 306.06438   100
 rollsum_rcpp(df, 20)  11.640328  13.170309  15.166030  14.03039  16.060333  31.23998   100
   rollsum_R1(df, 20)   5.993384   7.128607   8.125868   7.54488   8.140206  19.86842   100

您也可以在 Rcpp 中编写自己的代码来解决短路问题,这比目前给出的两种方法更快。

在您的工作目录中将以下内容保存为 longest_rollsum.cpp:

#include <Rcpp.h>
using namespace Rcpp;


// [[Rcpp::export]]

int longest_rollsum_C(NumericVector x, double threshold){
  auto rollsum_t = [&x, threshold](int k) {
    for (int i = 0; i< x.length() - k; i++) 
      if(sum(x[seq(i, i+k)]) < threshold) return false;
    return true;
   };
  for (int i = 0; i<x.length(); i++) if(rollsum_t(i)) return i;
  return 0;
}

在 R 中:获取上述文件

Rcpp::sourceCpp("longest_rollsum.cpp")
rollsum_R <- function(df){
  df %>% 
     group_by(g) %>%
    summarise(longest = longest_rollsum_C(x, 2))
}

microbenchmark::microbenchmark(rollsum(df), rollsum_rcpp(df), rollsum_R(df))
Unit: milliseconds
             expr       min        lq      mean    median        uq      max neval
      rollsum(df) 24.052665 25.018864 26.985276 25.453187 27.479305 37.49629   100
 rollsum_rcpp(df)  4.077397  4.352724  4.755942  4.572804  4.902468 13.10230   100
    rollsum_R(df)  2.271907  2.529000  2.871154  2.714801  2.955849 10.62107   100