如何用 Rust 宏简化数学公式?
How to simplify mathematical formulas with rust macros?
我必须承认我对宏有点迷茫。
我想构建一个执行以下任务的宏,并且
我不知道该怎么做。我想执行标量积
两个数组,比如 x 和 y,它们具有相同的长度 N。
我要计算的结果是以下形式:
z = sum_{i=0}^{N-1} x[i] * y[i].
x
是 const
哪些元素是 0, 1, or -1
这在编译时是已知的,
而 y
的元素是在运行时确定的。因为
x
的结构,很多计算都是无用的(项乘以0
可以从和中去除,1 * y[i], -1 * y[i]
形式的乘法可以分别转化为y[i], -y[i]
)。
例如,如果 x = [-1, 1, 0]
,上面的标量积将是
z=-1 * y[0] + 1 * y[1] + 0 * y[2]
为了加快计算速度,我可以手动展开循环并重写
整个事情没有 x[i]
,我可以将上面的公式硬编码为
z = -y[0] + y[1]
但是这个程序不够优雅,容易出错
当 N 变大时非常乏味。
我很确定我可以用宏来做到这一点,但我不知道在哪里
开始(我读过的不同的书都没有深入到宏和
我卡住了)...
你们中的任何人都知道如何(如果可能的话)使用宏来解决这个问题吗?
提前感谢您的帮助!
编辑: 正如许多答案中所指出的,编译器足够聪明,可以在整数的情况下删除优化循环。我不仅使用整数而且还使用浮点数(x
数组是 i32s,但通常 y
是 f64
s),所以编译器不够聪明(理所当然)优化循环。下面的一段代码给出了下面的 asm.
const X: [i32; 8] = [0, 1, -1, 0, 0, 1, 0, -1];
pub fn dot_x(y: [f64; 8]) -> f64 {
X.iter().zip(y.iter()).map(|(i, j)| (*i as f64) * j).sum()
}
playground::dot_x:
xorpd %xmm0, %xmm0
movsd (%rdi), %xmm1
mulsd %xmm0, %xmm1
addsd %xmm0, %xmm1
addsd 8(%rdi), %xmm1
subsd 16(%rdi), %xmm1
movupd 24(%rdi), %xmm2
xorpd %xmm3, %xmm3
mulpd %xmm2, %xmm3
addsd %xmm3, %xmm1
unpckhpd %xmm3, %xmm3
addsd %xmm1, %xmm3
addsd 40(%rdi), %xmm3
mulsd 48(%rdi), %xmm0
addsd %xmm3, %xmm0
subsd 56(%rdi), %xmm0
retq
您可以使用 returns 函数的宏来实现您的目标。
首先,在没有宏的情况下编写这个函数。这个接受固定数量的参数。
fn main() {
println!("Hello, world!");
let func = gen_sum([1,2,3]);
println!("{}", func([4,5,6])) // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32
}
fn gen_sum(xs: [i32; 3]) -> impl Fn([i32;3]) -> i32 {
move |ys| ys[0]*xs[0] + ys[1]*xs[1] + ys[2]*xs[2]
}
现在,完全重写它,因为先前的设计不能很好地用作宏。我们不得不放弃固定大小的数组,如 macros appear unable to allocate fixed-sized arrays.
fn main() {
let func = gen_sum!(1,2,3);
println!("{}", func(vec![4,5,6])) // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32
}
#[macro_export]
macro_rules! gen_sum {
( $( $x:expr ),* ) => {
{
let mut xs = Vec::new();
$(
xs.push($x);
)*
move |ys:Vec<i32>| {
if xs.len() != ys.len() {
panic!("lengths don't match")
}
let mut total = 0;
for i in 0 as usize .. xs.len() {
total += xs[i] * ys[i];
}
total
}
}
};
}
这个do/What应该做什么
在编译时,它生成一个 lambda。此 lambda 接受数字列表并将其乘以在编译时生成的 vec。我不认为这正是您所追求的,因为它不会在编译时优化掉零。您可以在编译时优化掉零,但是您必须在 运行 时间检查零在 x 中的位置以确定在 y 中乘以哪些元素,这必然会产生一些成本。您甚至可以使用哈希集在恒定时间内完成此查找过程。一般来说,它仍然可能不值得(我认为 0 并不那么常见)。计算机更擅长做一件 "inefficient" 的事情,而不是检测他们将要做的事情是 "inefficient" 然后跳过那件事。当他们所做的大部分操作是 "inefficient"
时,这种抽象就会崩溃
跟进
这值得吗?它提高了 运行 倍吗?我没有测量,但与仅使用函数相比,理解和维护我编写的宏似乎不值得。编写一个执行您所说的零优化的宏可能会更不愉快。
在很多情况下,编译器的优化阶段会为您解决这个问题。举个例子,这个函数定义
const X: [i32; 8] = [0, 1, -1, 0, 0, 1, 0, -1];
pub fn dot_x(y: [i32; 8]) -> i32 {
X.iter().zip(y.iter()).map(|(i, j)| i * j).sum()
}
在 x86_64:
上产生此程序集输出
playground::dot_x:
mov eax, dword ptr [rdi + 4]
sub eax, dword ptr [rdi + 8]
add eax, dword ptr [rdi + 20]
sub eax, dword ptr [rdi + 28]
ret
您将无法获得比这更优化的版本,因此简单地以天真的方式编写代码是最好的解决方案。编译器是否会为更长的向量展开循环尚不清楚,它可能会随着编译器版本的不同而改变。
对于浮点数,编译器通常无法执行上述所有优化,因为 y
中的数字不能保证是有限的——它们也可能是 NaN
, inf
或 -inf
。因此,与 0.0
相乘并不能保证再次得到 0.0
,因此编译器需要在代码中保留乘法指令。但是,您可以通过使用 fmul_fast()
内部函数明确允许它假设所有数字都是有限的:
#![feature(core_intrinsics)]
use std::intrinsics::fmul_fast;
const X: [i32; 8] = [0, 1, -1, 0, 0, 1, 0, -1];
pub fn dot_x(y: [f64; 8]) -> f64 {
X.iter().zip(y.iter()).map(|(i, j)| unsafe { fmul_fast(*i as f64, *j) }).sum()
}
这导致以下汇编代码:
playground::dot_x: # @playground::dot_x
# %bb.0:
xorpd xmm1, xmm1
movsd xmm0, qword ptr [rdi + 8] # xmm0 = mem[0],zero
addsd xmm0, xmm1
subsd xmm0, qword ptr [rdi + 16]
addsd xmm0, xmm1
addsd xmm0, qword ptr [rdi + 40]
addsd xmm0, xmm1
subsd xmm0, qword ptr [rdi + 56]
ret
这仍然会在步骤之间冗余地添加零,但我不希望这会导致实际 CFD 模拟的任何可测量的开销,因为此类模拟往往受内存带宽而不是 CPU 的限制。如果你也想避免这些添加,你需要使用fadd_fast()
进行添加,让编译器进一步优化:
#![feature(core_intrinsics)]
use std::intrinsics::{fadd_fast, fmul_fast};
const X: [i32; 8] = [0, 1, -1, 0, 0, 1, 0, -1];
pub fn dot_x(y: [f64; 8]) -> f64 {
let mut result = 0.0;
for (&i, &j) in X.iter().zip(y.iter()) {
unsafe { result = fadd_fast(result, fmul_fast(i as f64, j)); }
}
result
}
这导致以下汇编代码:
playground::dot_x: # @playground::dot_x
# %bb.0:
movsd xmm0, qword ptr [rdi + 8] # xmm0 = mem[0],zero
subsd xmm0, qword ptr [rdi + 16]
addsd xmm0, qword ptr [rdi + 40]
subsd xmm0, qword ptr [rdi + 56]
ret
与所有优化一样,您应该从最易读和可维护的代码版本开始。如果性能成为问题,您应该分析您的代码并找到瓶颈。下一步,尝试改进基本方法,例如通过使用具有更好渐近复杂度的算法。只有这样你才应该转向微优化,就像你在问题中建议的那样。
如果你可以节省 #[inline(always)]
可能使用显式 filter_map() 应该足以让编译器做你想做的事。
首先,(proc) 宏不能简单地查看数组内部 x
。它得到的只是你传递给它的令牌,没有任何上下文。如果你想让它知道值 (0, 1, -1),你需要将它们直接传递给你的宏:
let result = your_macro!(y, -1, 0, 1, -1);
但是您实际上并不需要为此使用宏。编译器进行了很多优化,如其他答案所示。但是,正如您在编辑中已经提到的那样,它不会优化掉 0.0 * x[i]
,因为结果并不总是 0.0
。 (例如,它可以是 -0.0
或 NaN
。)我们在这里可以做的,只是通过使用 match
或 if
来帮助优化器,以确保它对 0.0 * y
案例没有任何作用:
const X: [i32; 8] = [0, -1, 0, 0, 0, 0, 1, 0];
fn foobar(y: [f64; 8]) -> f64 {
let mut sum = 0.0;
for (&x, &y) in X.iter().zip(&y) {
if x != 0 {
sum += x as f64 * y;
}
}
sum
}
在发布模式下,展开循环并内联 X
的值,导致大多数迭代被丢弃,因为它们什么都不做。结果二进制文件中唯一剩下的东西(在 x86_64 上)是:
foobar:
xorpd xmm0, xmm0
subsd xmm0, qword, ptr, [rdi, +, 8]
addsd xmm0, qword, ptr, [rdi, +, 48]
ret
(As suggested by @lu-zero, this can also be done using filter_map
. That will look like this: X.iter().zip(&y).filter_map(|(&x, &y)| match x { 0 => None, _ => Some(x as f64 * y) }).sum()
, and gives the exact same generated assembly. Or even without a match
, by using filter
and map
separately: .filter(|(&x, _)| x != 0).map(|(&x, &y)| x as f64 * y).sum()
.)
不错!但是,此函数计算 0.0 - y[1] + y[6]
,因为 sum
从 0.0
开始,我们只对其进行减法和加法。优化器再次不愿意优化掉 0.0
。我们可以通过不从 0.0
开始,而是从 None
:
开始来帮助它
fn foobar(y: [f64; 8]) -> f64 {
let mut sum = None;
for (&x, &y) in X.iter().zip(&y) {
if x != 0 {
let p = x as f64 * y;
sum = Some(sum.map_or(p, |s| s + p));
}
}
sum.unwrap_or(0.0)
}
这导致:
foobar:
movsd xmm0, qword, ptr, [rdi, +, 48]
subsd xmm0, qword, ptr, [rdi, +, 8]
ret
这只是 y[6] - y[1]
。宾果!
我必须承认我对宏有点迷茫。 我想构建一个执行以下任务的宏,并且 我不知道该怎么做。我想执行标量积 两个数组,比如 x 和 y,它们具有相同的长度 N。 我要计算的结果是以下形式:
z = sum_{i=0}^{N-1} x[i] * y[i].
x
是 const
哪些元素是 0, 1, or -1
这在编译时是已知的,
而 y
的元素是在运行时确定的。因为
x
的结构,很多计算都是无用的(项乘以0
可以从和中去除,1 * y[i], -1 * y[i]
形式的乘法可以分别转化为y[i], -y[i]
)。
例如,如果 x = [-1, 1, 0]
,上面的标量积将是
z=-1 * y[0] + 1 * y[1] + 0 * y[2]
为了加快计算速度,我可以手动展开循环并重写
整个事情没有 x[i]
,我可以将上面的公式硬编码为
z = -y[0] + y[1]
但是这个程序不够优雅,容易出错 当 N 变大时非常乏味。
我很确定我可以用宏来做到这一点,但我不知道在哪里 开始(我读过的不同的书都没有深入到宏和 我卡住了)...
你们中的任何人都知道如何(如果可能的话)使用宏来解决这个问题吗?
提前感谢您的帮助!
编辑: 正如许多答案中所指出的,编译器足够聪明,可以在整数的情况下删除优化循环。我不仅使用整数而且还使用浮点数(x
数组是 i32s,但通常 y
是 f64
s),所以编译器不够聪明(理所当然)优化循环。下面的一段代码给出了下面的 asm.
const X: [i32; 8] = [0, 1, -1, 0, 0, 1, 0, -1];
pub fn dot_x(y: [f64; 8]) -> f64 {
X.iter().zip(y.iter()).map(|(i, j)| (*i as f64) * j).sum()
}
playground::dot_x:
xorpd %xmm0, %xmm0
movsd (%rdi), %xmm1
mulsd %xmm0, %xmm1
addsd %xmm0, %xmm1
addsd 8(%rdi), %xmm1
subsd 16(%rdi), %xmm1
movupd 24(%rdi), %xmm2
xorpd %xmm3, %xmm3
mulpd %xmm2, %xmm3
addsd %xmm3, %xmm1
unpckhpd %xmm3, %xmm3
addsd %xmm1, %xmm3
addsd 40(%rdi), %xmm3
mulsd 48(%rdi), %xmm0
addsd %xmm3, %xmm0
subsd 56(%rdi), %xmm0
retq
您可以使用 returns 函数的宏来实现您的目标。
首先,在没有宏的情况下编写这个函数。这个接受固定数量的参数。
fn main() {
println!("Hello, world!");
let func = gen_sum([1,2,3]);
println!("{}", func([4,5,6])) // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32
}
fn gen_sum(xs: [i32; 3]) -> impl Fn([i32;3]) -> i32 {
move |ys| ys[0]*xs[0] + ys[1]*xs[1] + ys[2]*xs[2]
}
现在,完全重写它,因为先前的设计不能很好地用作宏。我们不得不放弃固定大小的数组,如 macros appear unable to allocate fixed-sized arrays.
fn main() {
let func = gen_sum!(1,2,3);
println!("{}", func(vec![4,5,6])) // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32
}
#[macro_export]
macro_rules! gen_sum {
( $( $x:expr ),* ) => {
{
let mut xs = Vec::new();
$(
xs.push($x);
)*
move |ys:Vec<i32>| {
if xs.len() != ys.len() {
panic!("lengths don't match")
}
let mut total = 0;
for i in 0 as usize .. xs.len() {
total += xs[i] * ys[i];
}
total
}
}
};
}
这个do/What应该做什么
在编译时,它生成一个 lambda。此 lambda 接受数字列表并将其乘以在编译时生成的 vec。我不认为这正是您所追求的,因为它不会在编译时优化掉零。您可以在编译时优化掉零,但是您必须在 运行 时间检查零在 x 中的位置以确定在 y 中乘以哪些元素,这必然会产生一些成本。您甚至可以使用哈希集在恒定时间内完成此查找过程。一般来说,它仍然可能不值得(我认为 0 并不那么常见)。计算机更擅长做一件 "inefficient" 的事情,而不是检测他们将要做的事情是 "inefficient" 然后跳过那件事。当他们所做的大部分操作是 "inefficient"
时,这种抽象就会崩溃跟进
这值得吗?它提高了 运行 倍吗?我没有测量,但与仅使用函数相比,理解和维护我编写的宏似乎不值得。编写一个执行您所说的零优化的宏可能会更不愉快。
在很多情况下,编译器的优化阶段会为您解决这个问题。举个例子,这个函数定义
const X: [i32; 8] = [0, 1, -1, 0, 0, 1, 0, -1];
pub fn dot_x(y: [i32; 8]) -> i32 {
X.iter().zip(y.iter()).map(|(i, j)| i * j).sum()
}
在 x86_64:
上产生此程序集输出playground::dot_x:
mov eax, dword ptr [rdi + 4]
sub eax, dword ptr [rdi + 8]
add eax, dword ptr [rdi + 20]
sub eax, dword ptr [rdi + 28]
ret
您将无法获得比这更优化的版本,因此简单地以天真的方式编写代码是最好的解决方案。编译器是否会为更长的向量展开循环尚不清楚,它可能会随着编译器版本的不同而改变。
对于浮点数,编译器通常无法执行上述所有优化,因为 y
中的数字不能保证是有限的——它们也可能是 NaN
, inf
或 -inf
。因此,与 0.0
相乘并不能保证再次得到 0.0
,因此编译器需要在代码中保留乘法指令。但是,您可以通过使用 fmul_fast()
内部函数明确允许它假设所有数字都是有限的:
#![feature(core_intrinsics)]
use std::intrinsics::fmul_fast;
const X: [i32; 8] = [0, 1, -1, 0, 0, 1, 0, -1];
pub fn dot_x(y: [f64; 8]) -> f64 {
X.iter().zip(y.iter()).map(|(i, j)| unsafe { fmul_fast(*i as f64, *j) }).sum()
}
这导致以下汇编代码:
playground::dot_x: # @playground::dot_x
# %bb.0:
xorpd xmm1, xmm1
movsd xmm0, qword ptr [rdi + 8] # xmm0 = mem[0],zero
addsd xmm0, xmm1
subsd xmm0, qword ptr [rdi + 16]
addsd xmm0, xmm1
addsd xmm0, qword ptr [rdi + 40]
addsd xmm0, xmm1
subsd xmm0, qword ptr [rdi + 56]
ret
这仍然会在步骤之间冗余地添加零,但我不希望这会导致实际 CFD 模拟的任何可测量的开销,因为此类模拟往往受内存带宽而不是 CPU 的限制。如果你也想避免这些添加,你需要使用fadd_fast()
进行添加,让编译器进一步优化:
#![feature(core_intrinsics)]
use std::intrinsics::{fadd_fast, fmul_fast};
const X: [i32; 8] = [0, 1, -1, 0, 0, 1, 0, -1];
pub fn dot_x(y: [f64; 8]) -> f64 {
let mut result = 0.0;
for (&i, &j) in X.iter().zip(y.iter()) {
unsafe { result = fadd_fast(result, fmul_fast(i as f64, j)); }
}
result
}
这导致以下汇编代码:
playground::dot_x: # @playground::dot_x
# %bb.0:
movsd xmm0, qword ptr [rdi + 8] # xmm0 = mem[0],zero
subsd xmm0, qword ptr [rdi + 16]
addsd xmm0, qword ptr [rdi + 40]
subsd xmm0, qword ptr [rdi + 56]
ret
与所有优化一样,您应该从最易读和可维护的代码版本开始。如果性能成为问题,您应该分析您的代码并找到瓶颈。下一步,尝试改进基本方法,例如通过使用具有更好渐近复杂度的算法。只有这样你才应该转向微优化,就像你在问题中建议的那样。
如果你可以节省 #[inline(always)]
可能使用显式 filter_map() 应该足以让编译器做你想做的事。
首先,(proc) 宏不能简单地查看数组内部 x
。它得到的只是你传递给它的令牌,没有任何上下文。如果你想让它知道值 (0, 1, -1),你需要将它们直接传递给你的宏:
let result = your_macro!(y, -1, 0, 1, -1);
但是您实际上并不需要为此使用宏。编译器进行了很多优化,如其他答案所示。但是,正如您在编辑中已经提到的那样,它不会优化掉 0.0 * x[i]
,因为结果并不总是 0.0
。 (例如,它可以是 -0.0
或 NaN
。)我们在这里可以做的,只是通过使用 match
或 if
来帮助优化器,以确保它对 0.0 * y
案例没有任何作用:
const X: [i32; 8] = [0, -1, 0, 0, 0, 0, 1, 0];
fn foobar(y: [f64; 8]) -> f64 {
let mut sum = 0.0;
for (&x, &y) in X.iter().zip(&y) {
if x != 0 {
sum += x as f64 * y;
}
}
sum
}
在发布模式下,展开循环并内联 X
的值,导致大多数迭代被丢弃,因为它们什么都不做。结果二进制文件中唯一剩下的东西(在 x86_64 上)是:
foobar:
xorpd xmm0, xmm0
subsd xmm0, qword, ptr, [rdi, +, 8]
addsd xmm0, qword, ptr, [rdi, +, 48]
ret
(As suggested by @lu-zero, this can also be done using
filter_map
. That will look like this:X.iter().zip(&y).filter_map(|(&x, &y)| match x { 0 => None, _ => Some(x as f64 * y) }).sum()
, and gives the exact same generated assembly. Or even without amatch
, by usingfilter
andmap
separately:.filter(|(&x, _)| x != 0).map(|(&x, &y)| x as f64 * y).sum()
.)
不错!但是,此函数计算 0.0 - y[1] + y[6]
,因为 sum
从 0.0
开始,我们只对其进行减法和加法。优化器再次不愿意优化掉 0.0
。我们可以通过不从 0.0
开始,而是从 None
:
fn foobar(y: [f64; 8]) -> f64 {
let mut sum = None;
for (&x, &y) in X.iter().zip(&y) {
if x != 0 {
let p = x as f64 * y;
sum = Some(sum.map_or(p, |s| s + p));
}
}
sum.unwrap_or(0.0)
}
这导致:
foobar:
movsd xmm0, qword, ptr, [rdi, +, 48]
subsd xmm0, qword, ptr, [rdi, +, 8]
ret
这只是 y[6] - y[1]
。宾果!