为什么标准 R 中值函数比简单的 C++ 替代函数慢得多?
Why is standard R median function so much slower than a simple C++ alternative?
我在 C++
中实现了以下中位数,并通过 Rcpp
在 R
中使用了它:
// [[Rcpp::export]]
double median2(std::vector<double> x){
double median;
size_t size = x.size();
sort(x.begin(), x.end());
if (size % 2 == 0){
median = (x[size / 2 - 1] + x[size / 2]) / 2.0;
}
else {
median = x[size / 2];
}
return median;
}
如果我随后将性能与标准内置 R 中值函数进行比较,我通过 microbenchmark
得到以下结果
> x = rnorm(100)
> microbenchmark(median(x),median2(x))
Unit: microseconds
expr min lq mean median uq max neval
median(x) 25.469 26.990 34.96888 28.130 29.081 518.126 100
median2(x) 1.140 1.521 2.47486 1.901 2.281 47.897 100
为什么标准中值函数这么慢?这不是我所期望的...
我不确定您指的 "standard" 实现是什么。
无论如何:如果有的话,作为标准库的一部分,当然不允许更改向量中元素的顺序(正如您的实现所做的那样),因此它肯定必须继续工作复印件。
创建此副本需要时间和 CPU(以及大量内存),这会影响 运行 时间。
正如@joran 所指出的,您的代码非常专业,一般来说,不太通用的函数、算法等...通常性能更高。看看median.default
:
median.default
# function (x, na.rm = FALSE)
# {
# if (is.factor(x) || is.data.frame(x))
# stop("need numeric data")
# if (length(names(x)))
# names(x) <- NULL
# if (na.rm)
# x <- x[!is.na(x)]
# else if (any(is.na(x)))
# return(x[FALSE][NA])
# n <- length(x)
# if (n == 0L)
# return(x[FALSE][NA])
# half <- (n + 1L)%/%2L
# if (n%%2L == 1L)
# sort(x, partial = half)[half]
# else mean(sort(x, partial = half + 0L:1L)[half + 0L:1L])
# }
有几种操作可以适应缺失值的可能性,这些操作肯定会影响函数的整体执行时间。由于您的函数不会复制此行为,因此它可以消除一堆计算,但因此不会为具有缺失值的向量提供相同的结果:
median(c(1, 2, NA))
#[1] NA
median2(c(1, 2, NA))
#[1] 2
其他一些因素可能没有 处理 NA
的影响,但值得指出:
median
,连同它使用的一些函数,都是 S3 泛型,所以在方法调度上花费了少量时间
median
不仅仅适用于整数和数字向量;它还将处理 Date
、POSIXt
,可能还有一堆其他 类,并正确保留属性:
median(Sys.Date() + 0:4)
#[1] "2016-01-15"
median(Sys.time() + (0:4) * 3600 * 24)
#[1] "2016-01-15 11:14:31 EST"
编辑:
我应该提到下面的函数 将导致原始向量被排序 因为 NumericVector
s 是代理对象。如果你想避免这种情况,你可以 Rcpp::clone
输入向量并对克隆进行操作,或者使用你的原始签名(带有 std::vector<double>
),这在从 [= 的转换中隐含地需要一个副本26=] 到 std::vector
。
另请注意,使用 NumericVector
代替 std::vector<double>
可以节省更多时间:
#include <Rcpp.h>
// [[Rcpp::export]]
double cpp_med(Rcpp::NumericVector x){
std::size_t size = x.size();
std::sort(x.begin(), x.end());
if (size % 2 == 0) return (x[size / 2 - 1] + x[size / 2]) / 2.0;
return x[size / 2];
}
microbenchmark::microbenchmark(
median(x),
median2(x),
cpp_med(x),
times = 200L
)
# Unit: microseconds
# expr min lq mean median uq max neval
# median(x) 74.787 81.6485 110.09870 92.5665 129.757 293.810 200
# median2(x) 6.474 7.9665 13.90126 11.0570 14.844 151.817 200
# cpp_med(x) 5.737 7.4285 11.25318 9.0270 13.405 52.184 200
Yakk 在上面的评论中提出了一个重要观点——Jerry Coffin 也对此进行了详细阐述——关于进行完整排序的效率低下。这是使用 std::nth_element
重写的,以更大的向量为基准:
#include <Rcpp.h>
// [[Rcpp::export]]
double cpp_med2(Rcpp::NumericVector xx) {
Rcpp::NumericVector x = Rcpp::clone(xx);
std::size_t n = x.size() / 2;
std::nth_element(x.begin(), x.begin() + n, x.end());
if (x.size() % 2) return x[n];
return (x[n] + *std::max_element(x.begin(), x.begin() + n)) / 2.;
}
set.seed(123)
xx <- rnorm(10e5)
all.equal(cpp_med2(xx), median(xx))
all.equal(median2(xx), median(xx))
microbenchmark::microbenchmark(
cpp_med2(xx), median2(xx),
median(xx), times = 200L
)
# Unit: milliseconds
# expr min lq mean median uq max neval
# cpp_med2(xx) 10.89060 11.34894 13.15313 12.72861 13.56161 33.92103 200
# median2(xx) 84.29518 85.47184 88.57361 86.05363 87.70065 228.07301 200
# median(xx) 46.18976 48.36627 58.77436 49.31659 53.46830 250.66939 200
[这是对您实际提出的问题的回答,而不是扩展评论。]
甚至您的代码也可以进行重大改进。特别是,即使您只关心一两个元素,您也要对整个输入进行排序。
您可以使用 std::nth_element
而不是 std::sort
将其从 O(n log n) 更改为 O(n)。在元素数量为偶数的情况下,您通常希望使用 std::nth_element
来查找中间之前的元素,然后使用 std::min_element
来查找紧随其后的元素——但是 std::nth_element
还对输入项进行分区,因此 std::min_element
只需 运行 在 nth_element
之后中间以上的项上,而不是整个输入数组。也就是说,在nth_element之后,你会得到这样的情况:
std::nth_element
的复杂度是"linear on average",(当然)std::min_element
也是线性的,所以整体复杂度是线性的。
因此,对于简单的情况(元素的奇数个),您会得到如下内容:
auto pos = x.begin() + x.size()/2;
std::nth_element(x.begin(), pos, x.end());
return *pos;
...对于更复杂的情况(偶数个元素):
std::nth_element(x.begin(), pos, x.end());
auto pos2 = std::min_element(pos+1, x.end());
return (*pos + *pos2) / 2.0;
从 here 可以期望 max_element( ForwardIt first, ForwardIt last )
从头到尾提供最大值,但是通过这样做: return (x[n] + *std::max_element(x.begin(), x.begin() + n)) / 2.
x.begin() + n
元素似乎被排除在外计算。为什么会出现这种差异?
例如cpp_med2({6, 2, 1, 5, 3, 4})
产生 x={2, 1, 3, 4, 5, 6}
其中:
n = 3
*x[n] = 4
*x.begin() = 2
*(x.begin() + n) = 4
*std::max_element(x.begin(), x.begin() + n) = 3
所以 cpp_med2({6, 2, 1, 5, 3, 4})
returns (4+3)/2=3.5 这是正确的中位数。
但是为什么 *std::max_element(x.begin(), x.begin() + n)
等于 3 而不是 4?该函数实际上似乎排除了最大值计算中的最后一个元素 (4)。
已解决(我认为):在:
Finds the greatest element in the range [first, last)
)
关闭意味着最后一个被排除在计算之外。对吗?
此致
我在 C++
中实现了以下中位数,并通过 Rcpp
在 R
中使用了它:
// [[Rcpp::export]]
double median2(std::vector<double> x){
double median;
size_t size = x.size();
sort(x.begin(), x.end());
if (size % 2 == 0){
median = (x[size / 2 - 1] + x[size / 2]) / 2.0;
}
else {
median = x[size / 2];
}
return median;
}
如果我随后将性能与标准内置 R 中值函数进行比较,我通过 microbenchmark
> x = rnorm(100)
> microbenchmark(median(x),median2(x))
Unit: microseconds
expr min lq mean median uq max neval
median(x) 25.469 26.990 34.96888 28.130 29.081 518.126 100
median2(x) 1.140 1.521 2.47486 1.901 2.281 47.897 100
为什么标准中值函数这么慢?这不是我所期望的...
我不确定您指的 "standard" 实现是什么。
无论如何:如果有的话,作为标准库的一部分,当然不允许更改向量中元素的顺序(正如您的实现所做的那样),因此它肯定必须继续工作复印件。
创建此副本需要时间和 CPU(以及大量内存),这会影响 运行 时间。
正如@joran 所指出的,您的代码非常专业,一般来说,不太通用的函数、算法等...通常性能更高。看看median.default
:
median.default
# function (x, na.rm = FALSE)
# {
# if (is.factor(x) || is.data.frame(x))
# stop("need numeric data")
# if (length(names(x)))
# names(x) <- NULL
# if (na.rm)
# x <- x[!is.na(x)]
# else if (any(is.na(x)))
# return(x[FALSE][NA])
# n <- length(x)
# if (n == 0L)
# return(x[FALSE][NA])
# half <- (n + 1L)%/%2L
# if (n%%2L == 1L)
# sort(x, partial = half)[half]
# else mean(sort(x, partial = half + 0L:1L)[half + 0L:1L])
# }
有几种操作可以适应缺失值的可能性,这些操作肯定会影响函数的整体执行时间。由于您的函数不会复制此行为,因此它可以消除一堆计算,但因此不会为具有缺失值的向量提供相同的结果:
median(c(1, 2, NA))
#[1] NA
median2(c(1, 2, NA))
#[1] 2
其他一些因素可能没有 处理 NA
的影响,但值得指出:
median
,连同它使用的一些函数,都是 S3 泛型,所以在方法调度上花费了少量时间median
不仅仅适用于整数和数字向量;它还将处理Date
、POSIXt
,可能还有一堆其他 类,并正确保留属性:
median(Sys.Date() + 0:4)
#[1] "2016-01-15"
median(Sys.time() + (0:4) * 3600 * 24)
#[1] "2016-01-15 11:14:31 EST"
编辑:
我应该提到下面的函数 将导致原始向量被排序 因为 NumericVector
s 是代理对象。如果你想避免这种情况,你可以 Rcpp::clone
输入向量并对克隆进行操作,或者使用你的原始签名(带有 std::vector<double>
),这在从 [= 的转换中隐含地需要一个副本26=] 到 std::vector
。
另请注意,使用 NumericVector
代替 std::vector<double>
可以节省更多时间:
#include <Rcpp.h>
// [[Rcpp::export]]
double cpp_med(Rcpp::NumericVector x){
std::size_t size = x.size();
std::sort(x.begin(), x.end());
if (size % 2 == 0) return (x[size / 2 - 1] + x[size / 2]) / 2.0;
return x[size / 2];
}
microbenchmark::microbenchmark(
median(x),
median2(x),
cpp_med(x),
times = 200L
)
# Unit: microseconds
# expr min lq mean median uq max neval
# median(x) 74.787 81.6485 110.09870 92.5665 129.757 293.810 200
# median2(x) 6.474 7.9665 13.90126 11.0570 14.844 151.817 200
# cpp_med(x) 5.737 7.4285 11.25318 9.0270 13.405 52.184 200
Yakk 在上面的评论中提出了一个重要观点——Jerry Coffin 也对此进行了详细阐述——关于进行完整排序的效率低下。这是使用 std::nth_element
重写的,以更大的向量为基准:
#include <Rcpp.h>
// [[Rcpp::export]]
double cpp_med2(Rcpp::NumericVector xx) {
Rcpp::NumericVector x = Rcpp::clone(xx);
std::size_t n = x.size() / 2;
std::nth_element(x.begin(), x.begin() + n, x.end());
if (x.size() % 2) return x[n];
return (x[n] + *std::max_element(x.begin(), x.begin() + n)) / 2.;
}
set.seed(123)
xx <- rnorm(10e5)
all.equal(cpp_med2(xx), median(xx))
all.equal(median2(xx), median(xx))
microbenchmark::microbenchmark(
cpp_med2(xx), median2(xx),
median(xx), times = 200L
)
# Unit: milliseconds
# expr min lq mean median uq max neval
# cpp_med2(xx) 10.89060 11.34894 13.15313 12.72861 13.56161 33.92103 200
# median2(xx) 84.29518 85.47184 88.57361 86.05363 87.70065 228.07301 200
# median(xx) 46.18976 48.36627 58.77436 49.31659 53.46830 250.66939 200
[这是对您实际提出的问题的回答,而不是扩展评论。]
甚至您的代码也可以进行重大改进。特别是,即使您只关心一两个元素,您也要对整个输入进行排序。
您可以使用 std::nth_element
而不是 std::sort
将其从 O(n log n) 更改为 O(n)。在元素数量为偶数的情况下,您通常希望使用 std::nth_element
来查找中间之前的元素,然后使用 std::min_element
来查找紧随其后的元素——但是 std::nth_element
还对输入项进行分区,因此 std::min_element
只需 运行 在 nth_element
之后中间以上的项上,而不是整个输入数组。也就是说,在nth_element之后,你会得到这样的情况:
std::nth_element
的复杂度是"linear on average",(当然)std::min_element
也是线性的,所以整体复杂度是线性的。
因此,对于简单的情况(元素的奇数个),您会得到如下内容:
auto pos = x.begin() + x.size()/2;
std::nth_element(x.begin(), pos, x.end());
return *pos;
...对于更复杂的情况(偶数个元素):
std::nth_element(x.begin(), pos, x.end());
auto pos2 = std::min_element(pos+1, x.end());
return (*pos + *pos2) / 2.0;
从 here 可以期望 max_element( ForwardIt first, ForwardIt last )
从头到尾提供最大值,但是通过这样做: return (x[n] + *std::max_element(x.begin(), x.begin() + n)) / 2.
x.begin() + n
元素似乎被排除在外计算。为什么会出现这种差异?
例如cpp_med2({6, 2, 1, 5, 3, 4})
产生 x={2, 1, 3, 4, 5, 6}
其中:
n = 3
*x[n] = 4
*x.begin() = 2
*(x.begin() + n) = 4
*std::max_element(x.begin(), x.begin() + n) = 3
所以 cpp_med2({6, 2, 1, 5, 3, 4})
returns (4+3)/2=3.5 这是正确的中位数。
但是为什么 *std::max_element(x.begin(), x.begin() + n)
等于 3 而不是 4?该函数实际上似乎排除了最大值计算中的最后一个元素 (4)。
已解决(我认为):在:
Finds the greatest element in the range [first, last)
)
关闭意味着最后一个被排除在计算之外。对吗?
此致