为什么标准 R 中值函数比简单的 C++ 替代函数慢得多?

Why is standard R median function so much slower than a simple C++ alternative?

我在 C++ 中实现了以下中位数,并通过 RcppR 中使用了它:

// [[Rcpp::export]]
double median2(std::vector<double> x){
  double median;
  size_t size = x.size();
  sort(x.begin(), x.end());
  if (size  % 2 == 0){
      median = (x[size / 2 - 1] + x[size / 2]) / 2.0;
  }
  else {
      median = x[size / 2];
  }
  return median;
}

如果我随后将性能与标准内置 R 中值函数进行比较,我通过 microbenchmark

得到以下结果
> x = rnorm(100)
> microbenchmark(median(x),median2(x))
Unit: microseconds
       expr    min     lq     mean median     uq     max neval
  median(x) 25.469 26.990 34.96888 28.130 29.081 518.126   100
 median2(x)  1.140  1.521  2.47486  1.901  2.281  47.897   100

为什么标准中值函数这么慢?这不是我所期望的...

我不确定您指的 "standard" 实现是什么。

无论如何:如果有的话,作为标准库的一部分,当然不允许更改向量中元素的顺序(正如您的实现所做的那样),因此它肯定必须继续工作复印件。

创建此副本需要时间和 CPU(以及大量内存),这会影响 运行 时间。

正如@joran 所指出的,您的代码非常专业,一般来说,不太通用的函数、算法等...通常性能更高。看看median.default

median.default
# function (x, na.rm = FALSE) 
# {
#   if (is.factor(x) || is.data.frame(x)) 
#     stop("need numeric data")
#   if (length(names(x))) 
#     names(x) <- NULL
#   if (na.rm) 
#     x <- x[!is.na(x)]
#   else if (any(is.na(x))) 
#     return(x[FALSE][NA])
#   n <- length(x)
#   if (n == 0L) 
#     return(x[FALSE][NA])
#   half <- (n + 1L)%/%2L
#   if (n%%2L == 1L) 
#     sort(x, partial = half)[half]
#   else mean(sort(x, partial = half + 0L:1L)[half + 0L:1L])
# }

有几种操作可以适应缺失值的可能性,这些操作肯定会影响函数的整体执行时间。由于您的函数不会复制此行为,因此它可以消除一堆计算,但因此不会为具有缺失值的向量提供相同的结果:

median(c(1, 2, NA))
#[1] NA

median2(c(1, 2, NA))
#[1] 2

其他一些因素可能没有 处理 NA 的影响,但值得指出:

  • median,连同它使用的一些函数,都是 S3 泛型,所以在方法调度上花费了少量时间
  • median 不仅仅适用于整数和数字向量;它还将处理 DatePOSIXt,可能还有一堆其他 类,并正确保留属性:

median(Sys.Date() + 0:4)
#[1] "2016-01-15"

median(Sys.time() + (0:4) * 3600 * 24)
#[1] "2016-01-15 11:14:31 EST"

编辑: 我应该提到下面的函数 将导致原始向量被排序 因为 NumericVectors 是代理对象。如果你想避免这种情况,你可以 Rcpp::clone 输入向量并对克隆进行操作,或者使用你的原始签名(带有 std::vector<double>),这在从 [= 的转换中隐含地需要一个副本26=] 到 std::vector

另请注意,使用 NumericVector 代替 std::vector<double> 可以节省更多时间:

#include <Rcpp.h>

// [[Rcpp::export]]
double cpp_med(Rcpp::NumericVector x){
  std::size_t size = x.size();
  std::sort(x.begin(), x.end());
  if (size  % 2 == 0) return (x[size / 2 - 1] + x[size / 2]) / 2.0;
  return x[size / 2];
}

microbenchmark::microbenchmark(
  median(x),
  median2(x),
  cpp_med(x),
  times = 200L
)
# Unit: microseconds
#       expr    min      lq      mean  median      uq     max neval
#  median(x) 74.787 81.6485 110.09870 92.5665 129.757 293.810   200
# median2(x)  6.474  7.9665  13.90126 11.0570  14.844 151.817   200
# cpp_med(x)  5.737  7.4285  11.25318  9.0270  13.405  52.184   200

Yakk 在上面的评论中提出了一个重要观点——Jerry Coffin 也对此进行了详细阐述——关于进行完整排序的效率低下。这是使用 std::nth_element 重写的,以更大的向量为基准:

#include <Rcpp.h>

// [[Rcpp::export]]
double cpp_med2(Rcpp::NumericVector xx) {
  Rcpp::NumericVector x = Rcpp::clone(xx);
  std::size_t n = x.size() / 2;
  std::nth_element(x.begin(), x.begin() + n, x.end());

  if (x.size() % 2) return x[n]; 
  return (x[n] + *std::max_element(x.begin(), x.begin() + n)) / 2.;
}

set.seed(123)
xx <- rnorm(10e5)

all.equal(cpp_med2(xx), median(xx))
all.equal(median2(xx), median(xx))

microbenchmark::microbenchmark(
  cpp_med2(xx), median2(xx), 
  median(xx), times = 200L
)
# Unit: milliseconds
#         expr      min       lq     mean   median       uq       max neval
# cpp_med2(xx) 10.89060 11.34894 13.15313 12.72861 13.56161  33.92103   200
#  median2(xx) 84.29518 85.47184 88.57361 86.05363 87.70065 228.07301   200
#   median(xx) 46.18976 48.36627 58.77436 49.31659 53.46830 250.66939   200

[这是对您实际提出的问题的回答,而不是扩展评论。]

甚至您的代码也可以进行重大改进。特别是,即使您只关心一两个元素,您也要对整个输入进行排序。

您可以使用 std::nth_element 而不是 std::sort 将其从 O(n log n) 更改为 O(n)。在元素数量为偶数的情况下,您通常希望使用 std::nth_element 来查找中间之前的元素,然后使用 std::min_element 来查找紧随其后的元素——但是 std::nth_element 还对输入项进行分区,因此 std::min_element 只需 运行 在 nth_element 之后中间以上的项上,而不是整个输入数组。也就是说,在nth_element之后,你会得到这样的情况:

std::nth_element的复杂度是"linear on average",(当然)std::min_element也是线性的,所以整体复杂度是线性的。

因此,对于简单的情况(元素的奇数个),您会得到如下内容:

auto pos = x.begin() + x.size()/2;

std::nth_element(x.begin(), pos, x.end());
return *pos;

...对于更复杂的情况(偶数个元素):

std::nth_element(x.begin(), pos, x.end());
auto pos2 = std::min_element(pos+1, x.end());
return (*pos + *pos2) / 2.0;

here 可以期望 max_element( ForwardIt first, ForwardIt last ) 从头到尾提供最大值,但是通过这样做: return (x[n] + *std::max_element(x.begin(), x.begin() + n)) / 2. x.begin() + n 元素似乎被排除在外计算。为什么会出现这种差异?

例如cpp_med2({6, 2, 1, 5, 3, 4}) 产生 x={2, 1, 3, 4, 5, 6} 其中:

n = 3
*x[n] = 4
*x.begin() = 2
*(x.begin() + n) = 4
*std::max_element(x.begin(), x.begin() + n) = 3

所以 cpp_med2({6, 2, 1, 5, 3, 4}) returns (4+3)/2=3.5 这是正确的中位数。 但是为什么 *std::max_element(x.begin(), x.begin() + n) 等于 3 而不是 4?该函数实际上似乎排除了最大值计算中的最后一个元素 (4)。

已解决(我认为):在:

Finds the greatest element in the range [first, last)

)关闭意味着最后一个被排除在计算之外。对吗?

此致