将函数作为参数传递,其中内部函数输入是迭代器

Pass function as an argument where inner function input is an Iterator

这是 的后续问题。

我想将函数 compute_var_iter 作为参数传递给另一个函数,但我正在努力理解错误消息:

这个有效:

use num::Float;

fn compute_var_iter<I, T>(vals: I) -> T
where
    I: Iterator<Item = T>,
    T: Float + std::ops::AddAssign,
{
    // online variance function
    let mut x = T::zero();
    let mut xsquare = T::zero();
    let mut len = T::zero();

    for val in vals {
        x += val;
        xsquare += val * val;
        len += T::one();
    }

    ((xsquare / len) - (x / len) * (x / len)) / (len - T::one()) * len
}

fn main() {
    let a = (1..10000001).map(|i| i as f64);
    let b = (1..10000001).map(|i| i as f64);

    dbg!(compute_var_iter(a.zip(b).map(|(a, b)| a * b)));
}

但是当我尝试这个时:

use num::Float;

fn compute_var_iter<I, T>(vals: I) -> T
where
    I: Iterator<Item = T>,
    T: Float + std::ops::AddAssign,
{
    // online variance function
    let mut x = T::zero();
    let mut xsquare = T::zero();
    let mut len = T::zero();

    for val in vals {
        x += val;
        xsquare += val * val;
        len += T::one();
    }

    ((xsquare / len) - (x / len) * (x / len)) / (len - T::one()) * len
}

fn use_fun<I, Fa, T>(aggregator: Fa, values: &[T], weights: &[T]) -> T
where
    I: Iterator<Item = T>,
    T: Float + std::ops::AddAssign,
    Fa: Fn(I) -> T
{
   aggregator(values[0..10].iter().zip(weights).map(|(x, y)| *x * *y))
}

fn main() {
    let a: Vec<f64> = (1..10000001).map(|i| i as f64).collect();
    let b: Vec<f64> = (1..10000001).map(|i| i as f64).collect();
    
    dbg!(use_fun(compute_var_iter, &a, &b));
}


我收到错误:

error[E0308]: mismatched types
  --> src/main.rs:43:16
   |
37 | fn use_fun<I, Fa, T>(aggregator: Fa, values: &[T], weights: &[T]) -> T
   |            - this type parameter
...
43 |     aggregator(values[0..10].iter().zip(weights).map(|(x, y)| *x * *y))
   |                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected type parameter `I`, found struct `Map`
   |
   = note: expected type parameter `I`
                      found struct `Map<Zip<std::slice::Iter<'_, T>, std::slice::Iter<'_, T>>, [closure@src/main.rs:43:54: 43:70]>`

error[E0282]: type annotations needed
  --> src/main.rs:54:10
   |
54 |     dbg!(use_fun(compute_var_iter, &a, &b));
   |          ^^^^^^^ cannot infer type for type parameter `I` declared on the function `use_fun`

类型注释应该如何工作才能得到这个运行?

fn use_fun<I, Fa, T>(aggregator: Fa, values: &[T], weights: &[T]) -> T

这个签名表明 I 是一个类型参数,这意味着 use_fun 调用者可以选择任何迭代器类型。

   aggregator(values[0..10].iter().zip(weights).map(|(x, y)| *x * *y))

但是在函数体中,你没有传递I类型的迭代器;您传递类型为 Map<Zip<std::slice::Iter<'_, T>, std::slice::Iter<'_, T>>, [closure@src/main.rs:43:54: 43:70]> 的迭代器,它是迭代器适配器链的类型以及您创建并传递的切片迭代器。

原则上,你需要做的是改变你的函数边界,要求提供的函数是一个接受特定类型的函数(因为不可能要求该函数对任何迭代器都是通用的)。不幸的是,目前无法使用稳定的 Rust 在 Rust 源代码中写下该类型(因为它包含闭包类型)。以下是一些解决方案:

  1. 使用不稳定的特性(仅在编译器的夜间构建中可用,在稳定版本中不可用),type_alias_impl_trait,为闭包命名:

    #![feature(type_alias_impl_trait)]
    
    type MyIter<'a, T: 'a> = impl Iterator<Item = T> + 'a;
    fn make_iter<'a, T>(values: &'a [T], weights: &'a [T]) -> MyIter<'a, T>
    where
        T: Float + std::ops::AddAssign + 'a,
    {
        values[0..10].iter().zip(weights).map(|(x, y)| *x * *y)
    }
    
    fn use_fun<'a, Fa, T>(aggregator: Fa, values: &'a [T], weights: &'a [T]) -> T
    where
        T: Float + std::ops::AddAssign,
        Fa: Fn(MyIter<'a, T>) -> T,
    {
        aggregator(make_iter(values, weights))
    }
    

    make_iter 辅助函数对于让 type ... = impl ... 知道它应该具有什么具体但不可命名的类型是必要的,方法是将 impl MyIterator 放在 return-类型位置。)

  2. 而不是使用 .map(closure),而是编写您自己的迭代器适配器(一个结构和一个 impl Iterator 用于它),因此它将具有可命名的类型。

  3. 使用函数指针将迭代器的类型声明为映射(这是一种可命名的类型,但如果优化器无法消除它,效率可能会稍低):

    type MyIter<'a, T> = std::iter::Map<
        std::iter::Zip<std::slice::Iter<'a, T>, std::slice::Iter<'a, T>>,
        fn((&T, &T)) -> T,
    >;
    
    fn use_fun<'a, Fa, T>(aggregator: Fa, values: &'a [T], weights: &'a [T]) -> T
    where
        T: Float + std::ops::AddAssign,
        Fa: Fn(MyIter<'a, T>) -> T,
    {
        aggregator(values[0..10].iter().zip(weights).map(|(x, y)| *x * *y))
    }
    
  4. 使用盒装类型擦除迭代器,Box<dyn Iterator>。 (这可能会比其他选项慢一些。)

    fn use_fun<'a, Fa, T>(aggregator: Fa, values: &'a [T], weights: &'a [T]) -> T
    where
        T: Float + std::ops::AddAssign,
        Fa: Fn(Box<dyn Iterator<Item = T> + 'a>) -> T,
    {
        aggregator(Box::new(values[0..10].iter().zip(weights).map(|(x, y)| *x * *y)))
    }
    
  5. 将集合而不是迭代器传递给 aggregator 函数。

  6. 不是将迭代器传递给 aggregator,而是让聚合器以“推送”方式接受数据:也许它可以实现 Extend 特征,或者只是一个FnMut 用每个值调用。这避免了聚合器接受 use_fn 选择的类型的需要——它只需要知道数字类型。