如何用 Rust 宏简化数学公式?

How to simplify mathematical formulas with rust macros?

我必须承认我对宏有点迷茫。 我想构建一个执行以下任务的宏,并且 我不知道该怎么做。我想执行标量积 两个数组,比如 x 和 y,它们具有相同的长度 N。 我要计算的结果是以下形式:

z = sum_{i=0}^{N-1} x[i] * y[i].

xconst 哪些元素是 0, 1, or -1 这在编译时是已知的, 而 y 的元素是在运行时确定的。因为 x的结构,很多计算都是无用的(项乘以0 可以从和中去除,1 * y[i], -1 * y[i]形式的乘法可以分别转化为y[i], -y[i])。

例如,如果 x = [-1, 1, 0],上面的标量积将是

z=-1 * y[0] + 1 * y[1] + 0 * y[2]

为了加快计算速度,我可以手动展开循环并重写 整个事情没有 x[i],我可以将上面的公式硬编码为

z = -y[0] + y[1]

但是这个程序不够优雅,容易出错 当 N 变大时非常乏味。

我很确定我可以用宏来做到这一点,但我不知道在哪里 开始(我读过的不同的书都没有深入到宏和 我卡住了)...

你们中的任何人都知道如何(如果可能的话)使用宏来解决这个问题吗?

提前感谢您的帮助!

编辑: 正如许多答案中所指出的,编译器足够聪明,可以在整数的情况下删除优化循环。我不仅使用整数而且还使用浮点数(x 数组是 i32s,但通常 yf64s),所以编译器不够聪明(理所当然)优化循环。下面的一段代码给出了下面的 asm.

const X: [i32; 8] = [0, 1, -1, 0, 0, 1, 0, -1];

pub fn dot_x(y: [f64; 8]) -> f64 {
    X.iter().zip(y.iter()).map(|(i, j)| (*i as f64) * j).sum()
}
playground::dot_x:
    xorpd   %xmm0, %xmm0
    movsd   (%rdi), %xmm1
    mulsd   %xmm0, %xmm1
    addsd   %xmm0, %xmm1
    addsd   8(%rdi), %xmm1
    subsd   16(%rdi), %xmm1
    movupd  24(%rdi), %xmm2
    xorpd   %xmm3, %xmm3
    mulpd   %xmm2, %xmm3
    addsd   %xmm3, %xmm1
    unpckhpd    %xmm3, %xmm3
    addsd   %xmm1, %xmm3
    addsd   40(%rdi), %xmm3
    mulsd   48(%rdi), %xmm0
    addsd   %xmm3, %xmm0
    subsd   56(%rdi), %xmm0
    retq

您可以使用 returns 函数的宏来实现您的目标。

首先,在没有宏的情况下编写这个函数。这个接受固定数量的参数。

fn main() {
    println!("Hello, world!");
    let func = gen_sum([1,2,3]);
    println!("{}", func([4,5,6])) // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32
}

fn gen_sum(xs: [i32; 3]) -> impl Fn([i32;3]) -> i32 {
    move |ys| ys[0]*xs[0] + ys[1]*xs[1] + ys[2]*xs[2]
}

现在,完全重写它,因为先前的设计不能很好地用作宏。我们不得不放弃固定大小的数组,如 macros appear unable to allocate fixed-sized arrays.

Rust Playground

fn main() {
    let func = gen_sum!(1,2,3);
    println!("{}", func(vec![4,5,6])) // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32
}

#[macro_export]
macro_rules! gen_sum {
    ( $( $x:expr ),* ) => {
        {
            let mut xs = Vec::new();
            $(
                xs.push($x);
            )*
            move |ys:Vec<i32>| {
                if xs.len() != ys.len() {
                    panic!("lengths don't match")
                }
                let mut total = 0;
                for i in 0 as usize .. xs.len() {
                    total += xs[i] * ys[i];
                }
                total
            } 
        }
    };
}

这个do/What应该做什么

在编译时,它生成一个 lambda。此 lambda 接受数字列表并将其乘以在编译时生成的 vec。我不认为这正是您所追求的,因为它不会在编译时优化掉零。您可以在编译时优化掉零,但是您必须在 运行 时间检查零在 x 中的位置以确定在 y 中乘以哪些元素,这必然会产生一些成本。您甚至可以使用哈希集在恒定时间内完成此查找过程。一般来说,它仍然可能不值得(我认为 0 并不那么常见)。计算机更擅长做一件 "inefficient" 的事情,而不是检测他们将要做的事情是 "inefficient" 然后跳过那件事。当他们所做的大部分操作是 "inefficient"

时,这种抽象就会崩溃

跟进

这值得吗?它提高了 运行 倍吗?我没有测量,但与仅使用函数相比,理解和维护我编写的宏似乎不值得。编写一个执行您所说的零优化的宏可能会更不愉快。

在很多情况下,编译器的优化阶段会为您解决这个问题。举个例子,这个函数定义

const X: [i32; 8] = [0, 1, -1, 0, 0, 1, 0, -1];

pub fn dot_x(y: [i32; 8]) -> i32 {
    X.iter().zip(y.iter()).map(|(i, j)| i * j).sum()
}

在 x86_64:

上产生此程序集输出
playground::dot_x:
    mov eax, dword ptr [rdi + 4]
    sub eax, dword ptr [rdi + 8]
    add eax, dword ptr [rdi + 20]
    sub eax, dword ptr [rdi + 28]
    ret

您将无法获得比这更优化的版本,因此简单地以天真的方式编写代码是最好的解决方案。编译器是否会为更长的向量展开循环尚不清楚,它可能会随着编译器版本的不同而改变。

对于浮点数,编译器通常无法执行上述所有优化,因为 y 中的数字不能保证是有限的——它们也可能是 NaNinf-inf。因此,与 0.0 相乘并不能保证再次得到 0.0,因此编译器需要在代码中保留乘法指令。但是,您可以通过使用 fmul_fast() 内部函数明确允许它假设所有数字都是有限的:

#![feature(core_intrinsics)]
use std::intrinsics::fmul_fast;

const X: [i32; 8] = [0, 1, -1, 0, 0, 1, 0, -1];

pub fn dot_x(y: [f64; 8]) -> f64 {
    X.iter().zip(y.iter()).map(|(i, j)| unsafe { fmul_fast(*i as f64, *j) }).sum()
}

这导致以下汇编代码:

playground::dot_x: # @playground::dot_x
# %bb.0:
    xorpd   xmm1, xmm1
    movsd   xmm0, qword ptr [rdi + 8] # xmm0 = mem[0],zero
    addsd   xmm0, xmm1
    subsd   xmm0, qword ptr [rdi + 16]
    addsd   xmm0, xmm1
    addsd   xmm0, qword ptr [rdi + 40]
    addsd   xmm0, xmm1
    subsd   xmm0, qword ptr [rdi + 56]
    ret

这仍然会在步骤之间冗余地添加零,但我不希望这会导致实际 CFD 模拟的任何可测量的开销,因为此类模拟往往受内存带宽而不是 CPU 的限制。如果你也想避免这些添加,你需要使用fadd_fast()进行添加,让编译器进一步优化:

#![feature(core_intrinsics)]
use std::intrinsics::{fadd_fast, fmul_fast};

const X: [i32; 8] = [0, 1, -1, 0, 0, 1, 0, -1];

pub fn dot_x(y: [f64; 8]) -> f64 {
    let mut result = 0.0;
    for (&i, &j) in X.iter().zip(y.iter()) {
        unsafe { result = fadd_fast(result, fmul_fast(i as f64, j)); }
    }
    result
}

这导致以下汇编代码:

playground::dot_x: # @playground::dot_x
# %bb.0:
    movsd   xmm0, qword ptr [rdi + 8] # xmm0 = mem[0],zero
    subsd   xmm0, qword ptr [rdi + 16]
    addsd   xmm0, qword ptr [rdi + 40]
    subsd   xmm0, qword ptr [rdi + 56]
    ret

与所有优化一样,您应该从最易读和可维护的代码版本开始。如果性能成为问题,您应该分析您的代码并找到瓶颈。下一步,尝试改进基本方法,例如通过使用具有更好渐近复杂度的算法。只有这样你才应该转向微优化,就像你在问题中建议的那样。

如果你可以节省 #[inline(always)] 可能使用显式 filter_map() 应该足以让编译器做你想做的事。

首先,(proc) 宏不能简单地查看数组内部 x。它得到的只是你传递给它的令牌,没有任何上下文。如果你想让它知道值 (0, 1, -1),你需要将它们直接传递给你的宏:

let result = your_macro!(y, -1, 0, 1, -1);

但是您实际上并不需要为此使用宏。编译器进行了很多优化,如其他答案所示。但是,正如您在编辑中已经提到的那样,它不会优化掉 0.0 * x[i],因为结果并不总是 0.0。 (例如,它可以是 -0.0NaN。)我们在这里可以做的,只是通过使用 matchif 来帮助优化器,以确保它对 0.0 * y 案例没有任何作用:

const X: [i32; 8] = [0, -1, 0, 0, 0, 0, 1, 0];

fn foobar(y: [f64; 8]) -> f64 {
    let mut sum = 0.0;
    for (&x, &y) in X.iter().zip(&y) {
        if x != 0 {
            sum += x as f64 * y;
        }
    }
    sum
}

在发布模式下,展开循环并内联 X 的值,导致大多数迭代被丢弃,因为它们什么都不做。结果二进制文件中唯一剩下的东西(在 x86_64 上)是:

foobar:
 xorpd   xmm0, xmm0
 subsd   xmm0, qword, ptr, [rdi, +, 8]
 addsd   xmm0, qword, ptr, [rdi, +, 48]
 ret

(As suggested by @lu-zero, this can also be done using filter_map. That will look like this: X.iter().zip(&y).filter_map(|(&x, &y)| match x { 0 => None, _ => Some(x as f64 * y) }).sum(), and gives the exact same generated assembly. Or even without a match, by using filter and map separately: .filter(|(&x, _)| x != 0).map(|(&x, &y)| x as f64 * y).sum().)

不错!但是,此函数计算 0.0 - y[1] + y[6],因为 sum0.0 开始,我们只对其进行减法和加法。优化器再次不愿意优化掉 0.0。我们可以通过不从 0.0 开始,而是从 None:

开始来帮助它
fn foobar(y: [f64; 8]) -> f64 {
    let mut sum = None;
    for (&x, &y) in X.iter().zip(&y) {
        if x != 0 {
            let p = x as f64 * y;
            sum = Some(sum.map_or(p, |s| s + p));
        }
    }
    sum.unwrap_or(0.0)
}

这导致:

foobar:
 movsd   xmm0, qword, ptr, [rdi, +, 48]
 subsd   xmm0, qword, ptr, [rdi, +, 8]
 ret

这只是 y[6] - y[1]。宾果!