R - Find intersect of common elements in two character strings

I am looking for the fastest way to count the elements two character strings have in common.

The elements in each string are separated by |.

Mock data:

library(data.table)
dt <- data.table(input1 = c("A|B", "C|D|", "R|S|T", "A|B"),
                 input2 = c("A|B|C|D|E|F", "C|D|E|F|G", "R|S|T", "X|Y|Z"))

Count the elements the two strings have in common and store the result in dt$outcome:

dt <- transform(dt, var1 = I(strsplit(as.character(input1), "\\|")))
dt <- transform(dt, var2 = I(strsplit(as.character(input2), "\\|")))
dt <- transform(dt, outcome = mapply(function(x, y) sum(x%in%y),
                                 var1, var2))

Result:

> dt
   input1      input2  var1        var2 outcome
1:    A|B A|B|C|D|E|F   A,B A,B,C,D,E,F       2
2:   C|D|   C|D|E|F|G   C,D   C,D,E,F,G       2
3:  R|S|T       R|S|T R,S,T       R,S,T       3
4:    A|B       X|Y|Z   A,B       X,Y,Z       0

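This example works fine, but the real data has thousands of elements in input1 and input2, and more than 200,000 rows. As a result, the current code runs for days and cannot be put into production.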

How can we speed this up?

dt$var1 and dt$var2 are not required outputs and can be omitted.

# input1, with its trailing "|" removed, is used as a regex alternation such as "A|B"
dt[, outcome := lengths(stringr::str_extract_all(input2, sub('[|]$', '', input1)))][]
   input1      input2 outcome
1:    A|B A|B|C|D|E|F       2
2:   C|D|   C|D|E|F|G       2
3:  R|S|T       R|S|T       3
4:    A|B       X|Y|Z       0
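The str_extract_all approach counts regex matches of input1 (read as an alternation like "A|B") anywhere in input2, rather than comparing whole |-separated elements. A minimal C++ sketch of that counting with std::regex, assuming it behaves like stringr's left-to-right, non-overlapping extraction for simple patterns like these; count_regex_matches is an illustrative name, not part of either answer:

```cpp
// Sketch: count non-overlapping regex matches, as
// lengths(str_extract_all(text, pattern)) does for a single row.
// Assumption: simple patterns such as "A|B" behave the same in
// std::regex (ECMAScript) as in stringr.
#include <cassert>
#include <cstddef>
#include <iterator>
#include <regex>
#include <string>

// Number of non-overlapping matches of `pattern` in `text`.
std::ptrdiff_t count_regex_matches(const std::string& text,
                                   const std::string& pattern) {
    const std::regex re(pattern);  // "A|B" means "A" or "B"
    return std::distance(
        std::sregex_iterator(text.begin(), text.end(), re),
        std::sregex_iterator());
}
```

On the mock data this gives 2, 2, 3 and 0. With multi-character elements the two notions diverge: "A" also matches inside "AB", so count_regex_matches("AB|C", "A|B") is 2 even though the rows share no whole element, and regex metacharacters in the elements (such as . or +) would need escaping.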

You can speed up the process by writing the code in C++, C, or Fortran. Here is what the C++ code looks like:

Rcpp::cppFunction('
  std::vector<int> count_intersect(std::vector<std::string> vec1,
                                   std::vector<std::string> vec2, char split) {
    // Split a string on `split`; a trailing empty token is dropped,
    // so "C|D|" gives {"C", "D"}.
    auto string_split = [=](const std::string& x) {
      std::vector<std::string> vec;
      std::string sub_string;
      for (char c : x) {
        if (c == split) {
          vec.push_back(sub_string);
          sub_string.clear();
        } else {
          sub_string += c;
        }
      }
      if (!sub_string.empty()) vec.push_back(sub_string);
      return vec;
    };

    // Count the elements of input1 that also occur in input2
    // (linear search with std::find; same rule as sum(x %in% y)).
    auto count = [=](const std::string& input1, const std::string& input2) {
      std::vector<std::string> in1 = string_split(input1);
      std::vector<std::string> in2 = string_split(input2);
      int total = 0;
      for (const auto& s : in1)
        if (std::find(in2.begin(), in2.end(), s) != in2.end()) total += 1;
      return total;
    };

    // One count per row.
    std::size_t len1 = vec1.size();
    std::vector<int> result(len1);
    for (std::size_t i = 0; i < len1; i++)
      result[i] = count(vec1[i], vec2[i]);
    return result;
  }', includes = "#include <algorithm>")

 dt[, outcome:=count_intersect(input1, input2, "|")][]
       input1      input2 outcome
    1:    A|B A|B|C|D|E|F       2
    2:   C|D|   C|D|E|F|G       2
    3:  R|S|T       R|S|T       3
    4:    A|B       X|Y|Z       0
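count_intersect looks up each element of input1 with std::find, a linear scan over input2's elements, so each row costs roughly len1 × len2 string comparisons; with thousands of elements per string that product can dominate. A hash set makes each lookup constant time on average. Here is a standalone C++ sketch of that idea; count_intersect_hashed and split_tokens are hypothetical names, and I have not benchmarked this variant:

```cpp
// Sketch: hash-based variant of count_intersect (hypothetical name).
// Splitting rule matches the Rcpp version: tokens are separated by
// `split`, and a trailing empty token is dropped.
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_set>
#include <vector>

// Split `x` on `split`, skipping a trailing empty token ("C|D|" -> {"C","D"}).
std::vector<std::string> split_tokens(const std::string& x, char split) {
    std::vector<std::string> out;
    std::string current;
    for (char c : x) {
        if (c == split) {
            out.push_back(current);
            current.clear();
        } else {
            current += c;
        }
    }
    if (!current.empty()) out.push_back(current);
    return out;
}

// For each row, count how many tokens of vec1[i] occur in vec2[i].
// Duplicates in vec1[i] are counted, like sum(x %in% y) in R.
std::vector<int> count_intersect_hashed(const std::vector<std::string>& vec1,
                                        const std::vector<std::string>& vec2,
                                        char split) {
    std::vector<int> result(vec1.size());
    for (std::size_t i = 0; i < vec1.size(); ++i) {
        // Build the lookup set once per row: O(len2) to build,
        // then O(1) on average per membership test.
        std::vector<std::string> tokens2 = split_tokens(vec2[i], split);
        std::unordered_set<std::string> lookup(tokens2.begin(), tokens2.end());
        int total = 0;
        for (const std::string& tok : split_tokens(vec1[i], split)) {
            if (lookup.count(tok) > 0) ++total;
        }
        result[i] = total;
    }
    return result;
}
```

On the mock data it returns 2, 2, 3, 0, the same as count_intersect. One possible way to call it from R is to put it in a .cpp file with an // [[Rcpp::export]] attribute on count_intersect_hashed and load it with Rcpp::sourceCpp(); that wiring is an assumption and is not shown in this answer.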
    

Benchmark on much larger data (200,000 rows):

bigdt <- mosaic::sample(dt, 200000, TRUE)[,1:2]
inputs <- c("input1", "input2")
vars <- c("var1", "var2")

bench::mark(OP = {
  bigdt <- transform(bigdt, var1 = I(strsplit(as.character(input1), "\\|")))
  bigdt <- transform(bigdt, var2 = I(strsplit(as.character(input2), "\\|")))
  bigdt <- transform(bigdt, outcome = mapply(function(x, y) sum(x%in%y), var1, var2))
},
r2evans = {
  bigdt[, (vars) := lapply(.SD, strsplit, "|", fixed = TRUE), .SDcols = inputs
  ][, outcome := mapply(function(x, y) sum(x %in% y), var1, var2)]
},
r2evans2 = {bigdt[, outcome := mapply(function(x, y) sum(x %in% y), 
                          strsplit(input1, "|", fixed = TRUE), 
                          strsplit(input2, "|", fixed = TRUE)) ]},
onyambu = {
  bigdt[, outcome:= lengths(stringr::str_extract_all(input2, sub('[|]$', '',input1)))]
},
onyambuCpp = bigdt[, outcome:=count_intersect(input1, input2, "|")],
 relative = TRUE
)



# A tibble: 5 x 13
  expression   min median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time result                     memory     time       gc      
  <bch:expr> <dbl>  <dbl>     <dbl>     <dbl>    <dbl> <int> <dbl>   <bch:tm> <list>                     <list>     <list>     <list>  
1 OP         12.4   12.1       1        30.9       Inf     1     6      1.66s <data.table [200,000 x 5]> <Rprofmem> <bench_tm> <tibble>
2 r2evans     4.77   4.66      2.60      5.72      Inf     1     3   641.39ms <data.table [200,000 x 5]> <Rprofmem> <bench_tm> <tibble>
3 r2evans2    6.08   5.94      2.04      5.70      Inf     1     5    817.4ms <data.table [200,000 x 5]> <Rprofmem> <bench_tm> <tibble>
4 onyambu     7.36   7.20      1.68      2.47      NaN     1     0   990.19ms <data.table [200,000 x 5]> <Rprofmem> <bench_tm> <tibble>
5 onyambuCpp  1      1        12.1       1         NaN     4     0   549.54ms <data.table [200,000 x 5]> <Rprofmem> <bench_tm> <tibble>

Note that the units are relative: the C++ version is at least 4x faster than the next-fastest method.

Two things should help:

  1. Use data.table's reference semantics, which are designed for efficiency and speed. Your use of transform slows you down considerably:

    bench::mark(
      base = { bigdt <- transform(bigdt, var1 = I(strsplit(as.character(input1), "\\|"))); },
      datatable = { bigdt[, var1 := strsplit(input1, "\\|")]; }
    )
    # # A tibble: 2 x 13
    #   expression      min   median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time result                   memory   time    gc    
    #   <bch:expr> <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm> <list>                   <list>   <list>  <list>
    # 1 base         2.69ms   3.44ms     271.      299KB     0      136     0      501ms <data.table [4,000 x 3]> <Rprofm~ <bench~ <tibb~
    # 2 datatable   11.33ms  13.53ms      68.0     110KB     2.27    30     1      441ms <data.table [4,000 x 3]> <Rprofm~ <bench~ <tibb~
    
  2. strsplit(., "\|") 转移到 strsplit(., "|", fixed = TRUE) 以减少正则表达式的开销。

    bench::mark(
      regex = strsplit(bigdt$input1, "\\|"), 
      fixed = strsplit(bigdt$input1, "|", fixed = TRUE)
    )
    # # A tibble: 2 x 13
    #   expression      min   median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time result         memory             time    gc    
    #   <bch:expr> <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm> <list>         <list>             <list>  <list>
    # 1 regex        1.94ms   2.12ms      419.    31.3KB     0      210     0      501ms <list [4,000]> <Rprofmem [1 x 3]> <bench~ <tibb~
    # 2 fixed       219.7us 246.95us     3442.    31.3KB     2.21  1554     1      452ms <list [4,000]> <Rprofmem [1 x 3]> <bench~ <tibb~
    

(Since the columns often use different units, I tend to treat `itr/sec` as a reasonable measure of relative performance.)
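Point 2 helps because a fixed split needs only a character comparison, while a regex split runs a pattern-matching engine over each string. The C++ sketch here is only an analogy (split_fixed and split_regex are illustrative names; this is not how R's strsplit is implemented internally). It shows that both produce the same tokens, including for the trailing | in "C|D|", which both drop, just as R's strsplit does:

```cpp
// Sketch: fixed-delimiter split versus regex split of the same string.
#include <cassert>
#include <regex>
#include <string>
#include <vector>

// Fixed split: compare each character with the delimiter.
std::vector<std::string> split_fixed(const std::string& x, char delim) {
    std::vector<std::string> out;
    std::string current;
    for (char c : x) {
        if (c == delim) {
            out.push_back(current);
            current.clear();
        } else {
            current += c;
        }
    }
    if (!current.empty()) out.push_back(current);  // drop trailing empty token
    return out;
}

// Regex split: "\\|" escapes the pipe, which would otherwise mean alternation.
// With submatch -1 the iterator yields the text between matches and omits an
// empty final piece, so "C|D|" also gives {"C", "D"}.
std::vector<std::string> split_regex(const std::string& x) {
    static const std::regex delim("\\|");
    std::sregex_token_iterator it(x.begin(), x.end(), delim, -1), end;
    return std::vector<std::string>(it, end);
}
```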

Combining these two techniques (and including onyambu's excellent suggestion), we see a significant improvement:

inputs <- c("input1", "input2")
vars <- c("var1", "var2")
bench::mark(OP = {
  bigdt <- transform(bigdt, var1 = I(strsplit(as.character(input1), "\\|")))
  bigdt <- transform(bigdt, var2 = I(strsplit(as.character(input2), "\\|")))
  bigdt <- transform(bigdt, outcome = mapply(function(x, y) sum(x%in%y), var1, var2))
},
r2evans = {
  bigdt[, (vars) := lapply(.SD, strsplit, "|", fixed = TRUE), .SDcols = inputs
       ][, outcome := mapply(function(x, y) sum(x %in% y), var1, var2)]
},
onyambu = {
  bigdt[, outcome:= lengths(stringr::str_extract_all(input2, sub('[|]$', '',input1)))]
}
)
# # A tibble: 3 x 13
#   expression      min   median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time result                   memory   time    gc    
#   <bch:expr> <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm> <list>                   <list>   <list>  <list>
# 1 OP           18.8ms  20.95ms      43.7    1.21MB     2.30    19     1      435ms <data.table [4,000 x 5]> <Rprofm~ <bench~ <tibb~
# 2 r2evans       7.5ms   8.42ms     105.   238.19KB     2.28    46     1      439ms <data.table [4,000 x 5]> <Rprofm~ <bench~ <tibb~
# 3 onyambu      10.9ms  11.87ms      80.8  130.36KB     0       41     0      508ms <data.table [4,000 x 5]> <Rprofm~ <bench~ <tibb~

The improvement holds as the data grows. Running the same benchmark on the larger table (biggerdt, 400,000 rows):
bench::mark(...)
# # A tibble: 3 x 13
#   expression      min   median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time result                     memory   time   gc   
#   <bch:expr> <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm> <list>                     <list>   <list> <lis>
# 1 OP            2.71s    2.71s     0.369   96.56MB     2.21     1     6      2.71s <data.table [400,000 x 5]> <Rprofm~ <benc~ <tib~
# 2 r2evans       1.38s    1.38s     0.723    17.8MB     2.17     1     3      1.38s <data.table [400,000 x 5]> <Rprofm~ <benc~ <tib~
# 3 onyambu       1.53s    1.53s     0.652    7.66MB     0        1     0      1.53s <data.table [400,000 x 5]> <Rprofm~ <benc~ <tib~

Although this is only a single iteration, both suggested answers are significantly faster than the base case.

We can improve a bit more if we adopt onyambu's suggestion and do not save the intermediate var1 and var2 values:

# r2evans_2
bigdt[, outcome := mapply(function(x, y) sum(x %in% y), 
                          strsplit(input1, "|", fixed = TRUE), 
                          strsplit(input2, "|", fixed = TRUE)) ]
bench::mark(...)
# # A tibble: 4 x 13
#   expression      min   median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time result                   memory   time    gc    
#   <bch:expr> <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm> <list>                   <list>   <list>  <list>
# 1 OP          18.27ms  18.85ms      52.7    1.21MB   190.       5    18     94.9ms <data.table [4,000 x 5]> <Rprofm~ <bench~ <tibb~
# 2 r2evans      7.28ms   8.18ms     123.   241.09KB   133.      24    26    195.7ms <data.table [4,000 x 5]> <Rprofm~ <bench~ <tibb~
# 3 r2evans_2    6.61ms   7.56ms     134.   205.57KB   105.      33    26      247ms <data.table [4,000 x 5]> <Rprofm~ <bench~ <tibb~
# 4 onyambu      10.7ms  12.21ms      82.8  110.88KB     2.02    41     1    495.2ms <data.table [4,000 x 5]> <Rprofm~ <bench~ <tibb~

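The trick with code-optimization questions like this is to break the big problem into smaller ones. I think this is a good start. If you need more speed, you may need to move to a compiled language or something else; offhand, I don't know how this could be improved further.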


Data, larger than your 4 rows:

bigdt <- rbindlist(replicate(1000, dt, simplify=FALSE))
biggerdt <- rbindlist(replicate(100000, dt, simplify=FALSE))