迭代 State Monad 并以良好的性能按顺序收集结果

Iterate State Monad and Collect Results in Sequence with Good Performance

我实现了以下功能:

iterateState :: Int -> (a -> State s a) -> (a -> State s [a])
iterateState 0 f a = return []
iterateState n f a = do
    b <- f a
    xs <- iterateState (n - 1) f b
    return $ b : xs

我的主要用例是 a = Double。它有效,但速度很慢。它分配 528MB 的堆 space 来生成 1M Double 值的列表,并将大部分时间用于垃圾收集。

我已经尝试过直接在类型 s -> (a, s) 上工作的实现以及各种严格性注释。我能够稍微减少堆分配,但甚至没有接近人们对合理实现的期望。我怀疑由此产生的 ([a], s) 是懒惰消耗的东西 ([a]) 和 WHNF 强制整个计算的东西 (s) 的组合,这使得 GHC 的优化变得困难。

假设列表的迭代性质不适合这种情况,我转向 vector 包。令我高兴的是,它已经包含

iterateNM :: (Monad m, Unbox a) => Int -> (a -> m a) -> a -> m (Vector a)

不幸的是,这只比我的列表实现稍微快一点,仍然分配 328MB 的堆 space。我假设这是因为它使用 unstreamM,其描述为

Load monadic stream bundle into a newly allocated vector. This function goes through a list, so prefer using unstream, unless you need to be in a monad.

查看其对列表 monad 的行为,可以理解一般 monad 没有有效的实现。幸运的是,我只需要 state monad,并且我找到了另一个 almost 符合 state monad 签名的函数。

unfoldrExactN :: Unbox a => Int -> (b -> (a, b)) -> b -> Vector a

此函数速度极快,并且不会执行超出保存 1M Double 值的结果未装箱向量所需的 8MB 的额外堆分配。不幸的是,它不是 return 计算结束时的最终状态,因此它不能包装在 State 类型中。

我查看了 unfoldrExactN 的实现,看看是否可以调整它以在计算结束时公开最终状态。不幸的是,这似乎很难,因为

unfoldrExactN :: Monad m => Int -> (s -> (a, s)) -> s -> Stream m a

最终被unstream展开成向量已经忘记了状态类型s

我想我可以绕过整个 Stream 基础设施并直接在 ST monad 中的可变向量上实现 iterateState(类似于 unstream 将流扩展为向量)。但是,我会失去流融合的所有好处,并且出于性能原因将很容易表示为纯函数的计算变成命令式的低级糊状物。当知道现有的 unfoldrExactN 已经计算出我想要的所有值,但我无法访问它们时,这尤其令人沮丧。

有没有更好的方法?

这个功能能否以纯函数的方式实现,性能合理,没有多余的堆分配?最好以与 vector 包及其流融合基础设施相关的方式。

以下程序在使用优化编译时在我的计算机上的最大驻留空间为 12MB:

import Data.Vector.Unboxed
import Data.Vector.Unboxed.Mutable

iterateNState :: Unbox a => Int -> (a -> s -> (s, a)) -> (a -> s -> (s, Vector a))
iterateNState n f a0 s0 = createT (unsafeNew n >>= go 0 a0 s0) where
    go i a s arr
        | i >= n = pure (s, arr)
        | otherwise = do
            unsafeWrite arr i a
            case f a s of
                (s', a') -> go (i+1) a' s' arr

main = id
    . print
    . Data.Vector.Unboxed.sum
    . snd
    $ iterateNState 1000000 (\a s -> (s+1, a+s :: Int)) 0 0

(即使从输入中动态读取最后两个 0,它仍然保持良好的低驻留。)