Rcpp 函数用于查找中位数,给定一个值及其频率的向量

Rcpp function to find the median, given a vector of values and their frequencies

我正在编写一个函数来查找一组值的中位数。数据显示为唯一值的向量(称它们为 'values')和它们的频率向量('freqs')。通常频率非常高,因此粘贴它们会占用过多的内存。我有一个缓慢的 R 实现,它是我代码中的主要瓶颈,所以我正在编写一个自定义 Rcpp 函数以用于 R/Bioconductor 包。 Bioconductor 的网站建议不要使用 C++11,所以这对我来说是个问题。

我的问题在于尝试根据值的顺序将两个向量排序在一起。在 R 中,我们可以只使用 order() 函数。我似乎无法让它工作,尽管遵循了关于这个问题的建议:C++ sorting and keeping track of indexes

以下几行是问题所在:

   // sort vector based on order of values
 IntegerVector idx_ord = std::sort(idx.begin(), idx.end(),
    bool (int i1, int i2) {return values[i1] < values[i2];});

这里是完整的功能,任何人都可以感兴趣。任何进一步的提示将不胜感激:

    #include <Rcpp.h>
using namespace Rcpp;

// [[Rcpp::export]]
double median_freq(NumericVector values, IntegerVector freqs) {
    int len = freqs.size();
    if (any(freqs!=0)){
        int med = 0;
        return med;
    }
    // filter out the zeros pre-sorting
    IntegerVector non_zeros;
    for (int i = 0; i < len; i++){
        if(freqs[i] != 0){
            non_zeros.push_back(i);
        }
    }
    freqs = freqs[non_zeros];
    values = values[non_zeros];
    // find the order of values
    // create integer vector of indices
    IntegerVector idx(len);
    for (int i = 0; i < len; ++i) idx[i] = i;

    // sort vector based on order of values
 IntegerVector idx_ord = std::sort(idx.begin(), idx.end(),
    bool (int i1, int i2) {return values[i1] < values[i2];});

    //apply to freqs and values
    freqs = freqs[idx_ord];
    values=values[idx_ord];
    IntegerVector cum_freqs(len);
    cum_freqs[0] = freqs[0];
    for (int i = 1; i < len; ++i) cum_freqs[i] = freqs[i] + cum_freqs[i-1];
    int total_freqs = cum_freqs[len-1];
    // split into odd and even frequencies and calculate the median
    if (total_freqs % 2 == 1) {
        int med_ind = (total_freqs + 1)/2 - 1; // C++ indexes from 0
        int i = 0;
        while ((i < len) && cum_freqs[i] < med_ind){
            i++;
        }
        double ret = values[i];
        return ret;
    } else {
        int med_ind_1 = total_freqs/2 - 1; // C++ indexes from 0
        int med_ind_2 = med_ind_1 + 1; // C++ indexes from 0
        int i = 0;
        while ((i < len) && cum_freqs[i] < med_ind_1){
            i++;
        }
        double ret_1 = values[i];
        i = 0;
        while ((i < len) && cum_freqs[i] < med_ind_2){
            i++;
        }
        double ret_2 = values[i];
        double ret = (ret_1 + ret_2)/2;
        return ret;
    }
}

对于使用 RUnit 测试框架的任何人,这里有一些基本的单元测试:

test_median_freq <- function(){
    checkEquals(median_freq(1:10,1:10),7)
    checkEquals(median_freq(1:10,rep(1,10)),5.5)
    checkEquals(median_freq(2:6,c(1,2,1,45,2)),5)
}

谢谢!

我实际上会将值和频率组合成一个 std::pair<double, int>,然后用 std::sort 对它们进行排序;通过这种方式,您始终可以将一个值及其频率保持在一起。这使您能够编写更简洁的代码,因为没有一组额外的索引浮动:

#include <Rcpp.h>
using namespace Rcpp;

// [[Rcpp::export]]
double median_freq(NumericVector values, IntegerVector freqs) {
  const int len = freqs.size();
  std::vector<std::pair<double, int> > allDat;
  int freqSum = 0;
  for (int i=0; i < len; ++i) {
    allDat.push_back(std::pair<double, int>(values[i], freqs[i]));
    freqSum += freqs[i];
  }
  std::sort(allDat.begin(), allDat.end());
  int accum = 0;
  for (int i=0; i < len; ++i) {
    accum += allDat[i].second;
    if (freqSum % 2 == 0) {
      if (accum > freqSum / 2) {
        return allDat[i].first;
      } else if (accum == freqSum / 2) {
        return (allDat[i].first + allDat[i+1].first) / 2;
      }
    } else {
      if (accum >= (freqSum+1)/2) {
        return allDat[i].first;
      }
    }
  }
  return NA_REAL;  // Should not be reached
}

在 R 中试用:

median_freq(1:10, 1:10)
# [1] 7
median_freq(1:10,rep(1,10))
# [1] 5.5
median_freq(2:6,c(1,2,1,45,2))
# [1] 5

我们还可以编写一个简单的 R 实现来确定我们从使用 Rcpp 中获得的效率增益:

med.freq.r <- function(values, freqs) {
  ord <- order(values)
  values <- values[ord]
  freqs <- freqs[ord]
  s <- sum(freqs)
  cs <- cumsum(freqs)
  idx <- min(which(cs >= s/2))
  if (s %% 2 == 0 && cs[idx] == s/2) {
    (values[idx] + values[idx+1]) / 2
  } else {
    values[idx]
  }
}
med.freq.r(1:10, 1:10)
# [1] 7
med.freq.r(1:10,rep(1,10))
# [1] 5.5
med.freq.r(2:6,c(1,2,1,45,2))
# [1] 5

为了进行基准测试,让我们看一组非常大的值:

set.seed(144)
values <- rnorm(1000000)
freqs <- sample(1:100, 1000000, replace=TRUE)
all.equal(median_freq(values, freqs), med.freq.r(values, freqs))
# [1] TRUE
library(microbenchmark)
microbenchmark(median_freq(values, freqs), med.freq.r(values, freqs))
# Unit: milliseconds
#                        expr      min       lq     mean   median       uq      max neval
#  median_freq(values, freqs) 128.5322 131.6095 146.8360 145.6389 159.6117 165.0306    10
#   med.freq.r(values, freqs) 715.2187 744.5709 776.0539 765.9178 817.7157 855.1898    10

对于 100 万个条目,Rcpp 解决方案比 R 解决方案快大约 5 倍;考虑到编译开销,只有当您处理非常大的向量或者这是一个经常重复的选项时,该性能才具有吸引力。

线性时间方法

通常我们知道如何在不排序的情况下计算中位数(有关详细信息,请查看 http://www.cc.gatech.edu/~mihail/medianCMU.pdf)。虽然该算法比排序和迭代更精细,但它可以产生显着的加速:

double fast_median_freq(NumericVector values, IntegerVector freqs) {
  const int len = freqs.size();
  std::vector<std::pair<double, int> > allDat;
  int freqSum = 0;
  for (int i=0; i < len; ++i) {
    allDat.push_back(std::pair<double, int>(values[i], freqs[i]));
    freqSum += freqs[i];
  }

  int target = freqSum / 2;
  int low = 0;
  int high = len-1;
  while (true) {
    // Random pivot; move to the end
    int rnd = low + (rand() % (high-low+1));
    std::swap(allDat[rnd], allDat[high]);

    // In-place pivot
    int highPos = low;  // Start of values higher than pivot
    int lowSum = 0;  // Sum of frequencies of elements below pivot
    for (int pos=low; pos < high; ++pos) {
      if (allDat[pos].first <= allDat[high].first) {
        lowSum += allDat[pos].second;
        std::swap(allDat[highPos], allDat[pos]);
        ++highPos;
      }
    }
    std::swap(allDat[highPos], allDat[high]);  // Move pivot to "highPos"

    // If we found the element then return; o/w recurse on proper side
    if (lowSum >= target) {
      // Recurse on lower elements
      high = highPos - 1;
    } else if (lowSum + allDat[highPos].second >= target) {
      // Return
      if (target < lowSum + allDat[highPos].second || freqSum % 2 == 1) {
        return allDat[highPos].first;
      } else {
        double nextHighest = std::min_element(allDat.begin() + highPos+1, allDat.begin() + len-1)->first;
        return (allDat[highPos].first + nextHighest) / 2;
      }
    } else {
      // Recurse on higher elements
      low = highPos + 1;
      target -= (lowSum + allDat[highPos].second);
    }
  }
}

基准测试:

all.equal(median_freq(values, freqs), fast_median_freq(values, freqs))
[1] TRUE
microbenchmark(median_freq(values, freqs), med.freq.r(values, freqs), fast_median_freq(values, freqs), times=10)
# Unit: milliseconds
#                             expr       min        lq      mean    median        uq       max neval
#       median_freq(values, freqs) 119.57989 122.48622 130.47841 130.48811 132.75421 146.36136    10
#        med.freq.r(values, freqs) 665.72803 690.15016 708.05729 702.65885 731.83936 749.36834    10
#  fast_median_freq(values, freqs)  24.37572  29.39641  31.86144  31.77459  34.88418  36.81606    10

线性方法比先排序后迭代 Rcpp 解决方案提速 4 倍,比基本 R 解决方案提速 20 倍。