如何对 macro_rules 的参数对进行操作?

How can I operate on pairs of arguments to macro_rules?

我正在编译时使用 const 泛型和宏构建一个简单的前馈神经网络。这些是一个接一个的一堆矩阵。

我创建了 network! 宏,其工作方式如下:

network!(2, 4, 1)

第一项为输入个数,其余为每层神经元个数。宏如下所示:

#[macro_export]
macro_rules! network {
    ( $inputs:expr, $($outputs:expr),* ) => {
        {
            Network {
                layers: [
                    $(
                        &Layer::<$inputs, $outputs>::new(),
                    )*
                ]
            }
        }
    };
}

它声明了一个层元素数组,它使用常量泛型在每一层上有一个固定大小的权重数组,第一个类型参数是它期望的输入数量,第二个类型参数是输入的数量输出。

此宏生成以下代码:

Network {
    layers: [
         &Layer::<2, 4>::new(),
         &Layer::<2, 1>::new(),
    ]
}

这是完全错误的,因为每一层的输入数量应该是前一层的输出数量,就像这样(注意 2 -> 4):

Network {
    layers: [
         &Layer::<2, 4>::new(),
         &Layer::<4, 1>::new(),
    ]
}

为此,我需要在每次迭代时用 $outputs 的值替换 $inputs 值,但我不知道该怎么做。

您可以匹配 两个 前导值,然后匹配所有其他值。对这两个值做一些特定的事情并递归调用宏,重用第二个值:

struct Layer<const I: usize, const O: usize>;

macro_rules! example {
    // Do something interesting for a given pair of arguments
    ($a:literal, $b:literal) => {
        Layer::<$a, $b>;
    };

    // Recursively traverse the arguments
    ($a:literal, $b:literal, $($rest:literal),+) => {
        example!($a, $b);
        example!($b, $($rest),*);
    };
}

fn main() {
    example!(1, 2, 3);
}

扩展宏导致:

fn main() {
    Layer::<1, 2>;
    Layer::<2, 3>;
}

对于那些感兴趣的人,我终于能够根据@Shepmaster 的回答像这样填充我的网络:

struct Network<'a, const L: usize> {
    layers: [&'a dyn Forward; L],
}

macro_rules! network {
    // Recursively accumulate token tree
    (@accum ($a:literal, $b:literal, $($others:literal),+) $($e:tt)*) => {
        network!(@accum ($b, $($others),*) $($e)*, &Layer::<$a, $b>::new())
    };

    // Latest iteration, convert to expression
    (@accum ($a:literal, $b:literal) $($e:tt)*) => {[$($e)*, &Layer::<$a, $b>::new()]};

    // Entrance
    ($a:literal, $b:literal, $($others:literal),+) => {
        Network {
            layers: network!(@accum ($b, $($others),*) &Layer::<$a, $b>::new())
        }
    };
}

对于 network!(2, 3, 4, 5, 1) 它转换为:

Network {
     layers:
          [&Layer::<2, 3>::new(),
           &Layer::<3, 4>::new(),
           &Layer::<4, 5>::new(),
           &Layer::<5, 1>::new()]
};