用于按索引从嵌套结构中检索数据的过程宏

Procedural macro for retrieving data from a nested struct by index

我正在尝试编写一个 rust 派生宏,用于按索引从嵌套结构中检索数据。该结构仅包含原始类型 u8、i8、u16、i16、u32、i32、u64、i64 或其其他结构。我有一个枚举,它将叶字段数据封装在一个我称之为 Item() 的通用类型中。我希望宏创建一个 .get() 实现,其中 returns 一个基于 u16 索引的项目。

这是期望的行为。

#[derive(Debug, PartialEq, PartialOrd, Copy, Clone)]
pub enum Item {
    U8(u8),
    I8(i8),
    U16(u16),
    I16(i16),
    U32(u32),
    I32(i32),
    U64(u64),
    I64(i64),
}

struct NestedData {
    a: u16,
    b: i32,
}

#[derive(GetItem)]
struct Data {
    a: i32,
    b: u64,
    c: NestedData,
}

let data = Data {
        a: 42,
        b: 1000,
        c: NestedData { a: 500, b: -2 },
};

assert_eq!(data.get(0).unwrap(), Item::I32(42));
assert_eq!(data.get(1).unwrap(), Item::U64(1000));
assert_eq!(data.get(2).unwrap(), Item::U16(500));
assert_eq!(data.get(3).unwrap(), Item::I32(-2));

对于这个特定示例,我希望宏扩展为以下...

impl Data {
    pub fn get(&self, index: u16) -> Result<Item, Error> {
        match index {
            0 => Ok(Item::U16(self.a)),
            1 => Ok(Item::I32(self.b)),
            2 => Ok(Item::I32(self.c.a)),
            3 => Ok(Item::U64(self.c.b)),
            _ => Err(Error::BadIndex),
        }
    }
}

我有一个用于单层结构的工作宏,但我不确定如何修改它以支持嵌套结构。这是我所在的地方...

use proc_macro2::TokenStream;
use quote::quote;

use syn::{Data, DataStruct, DeriveInput, Fields, Type, TypePath};

pub fn impl_get_item(input: DeriveInput) -> syn::Result<TokenStream> {
    let model_name = input.ident;

    let fields = match input.data {
        Data::Struct(DataStruct {
            fields: Fields::Named(fields),
            ..
        }) => fields.named,
        _ => panic!("The GetItem derive can only be applied to structs"),
    };

    let mut matches = TokenStream::new();
    let mut item_index: u16 = 0;
    for field in fields {
        let item_name = field.ident;
        let item_type = field.ty;
        let ts = match item_type {
            Type::Path(TypePath { path, .. }) if path.is_ident("u8") => {
                quote! {#item_index => Ok(Item::U8(self.#item_name)),}
            }
            Type::Path(TypePath { path, .. }) if path.is_ident("i8") => {
                quote! {#item_index => Ok(Item::I8(self.#item_name)),}
            }
            Type::Path(TypePath { path, .. }) if path.is_ident("u16") => {
                quote! {#item_index => Ok(Item::U16(self.#item_name)),}
            }
            Type::Path(TypePath { path, .. }) if path.is_ident("i16") => {
                quote! {#item_index => Ok(Item::I16(self.#item_name)),}
            }
            Type::Path(TypePath { path, .. }) if path.is_ident("u32") => {
                quote! {#item_index => Ok(Item::U32(self.#item_name)),}
            }
            Type::Path(TypePath { path, .. }) if path.is_ident("i32") => {
                quote! {#item_index => Ok(Item::I32(self.#item_name)),}
            }
            Type::Path(TypePath { path, .. }) if path.is_ident("u64") => {
                quote! {#item_index => Ok(Item::U64(self.#item_name)),}
            }
            Type::Path(TypePath { path, .. }) if path.is_ident("i64") => {
                quote! {#item_index => Ok(Item::I64(self.#item_name)),}
            }
            _ => panic!("{:?} uses unsupported type {:?}", item_name, item_type),
        };
        matches.extend(ts);
        item_index += 1;
    }

    let output = quote! {
        #[automatically_derived]
        impl #model_name {
            pub fn get(&self, index: u16) -> Result<Item, Error> {
                match index {
                    #matches
                    _ => Err(Error::BadIndex),
                }
            }
        }
    };

    Ok(output)
}

我不会给出完整的答案,因为我的 proc-macro 技能是不存在的,但我认为一旦结构正确,宏部分就不会很棘手.

我采用的方法是定义一个所有类型都将使用的特征。我将其命名为 Indexible,这可能很糟糕。该特征的要点是提供 get 函数和该对象中包含的所有字段的计数。

trait Indexible {
    fn nfields(&self) -> usize;
    fn get(&self, idx:usize) -> Result<Item>;
}

我使用 fn nfields(&self) -> usize 而不是 fn nfields() -> usize 因为 &self 意味着我可以在向量和切片上使用它,可能还有一些其他类型(它也使下面的代码稍微更整洁)。

接下来你需要为你的基类型实现这个特征:

impl Indexible for u8 {
    fn nfields(&self) -> usize { 1 }
    fn get(&self, idx:usize) -> Result<Item> { Ok(Item::U8(*self)) }
}
...

生成所有这些可能是宏的一个很好的用途(但你正在谈论的 proc 宏)。

接下来,您需要为所需的类型生成这些:我的实现如下所示:

impl Indexible for NestedData {
    fn nfields(&self) -> usize {
        self.a.nfields() +
        self.b.nfields()
    }
    fn get(&self, idx:usize) -> Result<Item> {
        let idx = idx;
        
        // member a
        if idx < self.a.nfields() {
            return self.a.get(idx)
        }
        let idx = idx - self.a.nfields();
        
        // member b
        if idx < self.b.nfields() {
            return self.b.get(idx)
        }
        Err(())
    }
}

impl Indexible for Data {
    fn nfields(&self) -> usize {
        self.a.nfields() +
        self.b.nfields() +
        self.c.nfields()
    }

    fn get(&self, idx:usize) -> Result<Item> {
        let idx = idx;

        if idx < self.a.nfields() {
            return self.a.get(idx)
        }
        let idx = idx - self.a.nfields();

        if idx < self.b.nfields() {
            return self.b.get(idx)
        }
        let idx = idx - self.b.nfields();
        
        if idx < self.c.nfields() {
            return self.c.get(idx)
        }
        Err(())
    }
}

您可以看到完整的 运行 版本 in the playground

这些看起来可以很容易地由宏生成。

如果您想要更好地了解无法工作的类型的错误消息,您应该像这样将每个成员显式地视为可索引的:(self.a as Indexible).get(..).

这似乎不是特别有效,但编译器能够确定这些片段中的大部分是常量并将它们内联。例如使用rust 1.51 with -C opt-level=3,下面的函数

pub fn sum(data: &Data) -> usize {
    let mut sum = 0;
    for i in 0..data.nfields() {
        sum += match data.get(i) {
            Err(_) => panic!(),
            Ok(Item::U8(v)) => v as usize,
            Ok(Item::U16(v)) => v as usize,
            Ok(Item::I32(v)) => v as usize,
            Ok(Item::U64(v)) => v as usize,
            _ => panic!(),
        }
    }
    sum
}

编译成这个

example::sum:
        movsxd  rax, dword ptr [rdi + 8]
        movsxd  rcx, dword ptr [rdi + 12]
        movzx   edx, word ptr [rdi + 16]
        add     rax, qword ptr [rdi]
        add     rax, rdx
        add     rax, rcx
        ret

你可以在compiler explorer

中看到这个