惰性状态转换器在 2D 递归中急切地消耗惰性列表

Lazy state transformer consumes lazy list eagerly in 2D recursion

我正在使用状态变换器在 2D 递归遍历的每个点随机采样数据集,它输出一个二维样本网格列表,这些样本共同满足一个条件。我想懒惰地从结果中提取结果,但我的方法反而在提取第一个结果之前在每个点都耗尽了整个数据集。

具体来说,考虑这个程序:

import Control.Monad ( sequence, liftM2 )
import Data.Functor.Identity
import Control.Monad.State.Lazy ( StateT(..), State(..), runState )

walk :: Int -> Int -> [State Int [Int]]
walk _ 0 = [return [0]]
walk 0 _ = [return [0]]
walk x y =
  let st :: [State Int Int]
      st = [StateT (\s -> Identity (s, s + 1)), undefined]
      unst :: [State Int Int] -- degenerate state tf
      unst = [return 1, undefined]
  in map (\m_z -> do
      z <- m_z
      fmap concat $ sequence [
          liftM2 (zipWith (\x y -> x + y + z)) a b -- for 1D: map (+z) <$> a
          | a <- walk x (y - 1) -- depth
          , b <- walk (x - 1) y -- breadth -- comment out for 1D
        ]
    ) st -- vs. unst

main :: IO ()
main = do
  std <- getStdGen
  putStrLn $ show $ head $ fst $ (`runState` 0) $ head $ walk 2 2

程序遍历矩形网格从 (x, y)(0, 0) 并对所有结果求和,包括状态单子列表之一的值:非平凡变换器 st 读取和推进它们的状态,或普通的变形金刚 unst。有趣的是算法是否探索了 stunst.

的头部

在呈现的代码中,它抛出 undefined。我将此归因于我链接转换顺序的错误设计,特别是状态处理的问题,因为使用 unst 代替(即,将结果与状态转换解耦)确实会产生结果。然而,然后我发现一维递归即使使用状态变换器也能保持惰性(删除宽度步长 b <- walk... 并将 liftM2 块换成 fmap)。

如果我们trace (show (x, y)),我们也会看到它确实在触发之前遍历了整个网格:

$ cabal run
Build profile: -w ghc-8.6.5 -O1
...
(2,2)
(2,1)
(1,2)
(1,1)
(1,1)
sandbox: Prelude.undefined

我怀疑我对 sequence 的使用在这里有问题,但是由于 monad 的选择和 walk 的维度影响它的成功,我不能广泛地说 sequenceing转换本身就是严格性的来源。

是什么导致了 1D 和 2D 递归在严格性上的差异,我怎样才能达到我想要的惰性?

考虑以下简化示例:

import Control.Monad.State.Lazy

st :: [State Int Int]
st = [state (\s -> (s, s + 1)), undefined]

action1d = do
  a <- sequence st
  return $ map (2*) a

action2d = do
  a <- sequence st
  b <- sequence st
  return $ zipWith (+) a b

main :: IO ()
main = do
  print $ head $ evalState action1d 0
  print $ head $ evalState action2d 0

在这里,在 1D 和 2D 计算中,结果的头部明确地仅取决于输入的头部(仅 head a 对于 1D 动作以及 head ahead b 用于 2D 动作)。但是,在 2D 计算中,b(甚至只是它的头部)对当前状态存在 隐式 依赖,并且该状态取决于 [=91] 的评估=] a 的全部,而不仅仅是头部。

您的示例中有类似的依赖关系,尽管它被状态操作列表的使用所掩盖。

假设我们想要手动 运行 操作 walk22_head = head $ walk 2 2 并检查结果列表中的第一个整数:

main = print $ head $ evalState walk22_head

显式写入状态动作列表的元素st

st1, st2 :: State Int Int
st1 = state (\s -> (s, s+1))
st2 = undefined

我们可以将walk22_head写成:

walk22_head = do
  z <- st1
  a <- walk21_head
  b <- walk12_head
  return $ zipWith (\x y -> x + y + z) a b

请注意,这仅取决于定义的状态动作 st1 以及 walk 2 1walk 1 2 的头部。反过来,那些头可以写成:

walk21_head = do
  z <- st1
  a <- return [0] -- walk20_head
  b <- walk11_head
  return $ zipWith (\x y -> x + y + z) a b

walk12_head = do
  z <- st1
  a <- walk11_head
  b <- return [0] -- walk02_head
  return $ zipWith (\x y -> x + y + z) a b

同样,这些仅取决于定义的状态动作 st1walk 1 1 的头部。

现在,让我们试着写下walk11_head的定义:

walk11_head = do
  z <- st1
  a <- return [0]
  b <- return [0]
  return $ zipWith (\x y -> x + y + z) a b

这仅取决于定义的状态操作 st1,因此有了这些定义,如果我们 运行 main,我们将得到一个定义的答案:

> main
10

但这些定义并不准确!在每个 walk 1 2walk 2 1 中,头部动作都是 动作序列 ,从调用 walk11_head 的动作开始,但继续动作基于 walk11_tail。因此,更准确的定义是:

walk21_head = do
  z <- st1
  a <- return [0] -- walk20_head
  b <- walk11_head
  _ <- walk11_tail  -- side effect of the sequennce
  return $ zipWith (\x y -> x + y + z) a b

walk12_head = do
  z <- st1
  a <- walk11_head
  b <- return [0] -- walk02_head
  _ <- walk11_tail  -- side effect of the sequence
  return $ zipWith (\x y -> x + y + z) a b

与:

walk11_tail = do
  z <- undefined
  a <- return [0]
  b <- return [0]
  return [zipWith (\x y -> x + y + z) a b]

有了这些定义,运行单独使用 walk12_headwalk21_head 就没有问题了:

> head $ evalState walk12_head 0
1
> head $ evalState walk21_head 0
1

计算答案不需要此处的状态副作用,因此从不调用。但是,不可能按顺序 运行 它们:

> head $ evalState (walk12_head >> walk21_head) 0
*** Exception: Prelude.undefined
CallStack (from HasCallStack):
  error, called at libraries/base/GHC/Err.hs:78:14 in base:GHC.Err
  undefined, called at Lazy2D_2.hs:41:8 in main:Main

因此,尝试 运行 main 失败的原因相同:

> main
*** Exception: Prelude.undefined
CallStack (from HasCallStack):
  error, called at libraries/base/GHC/Err.hs:78:14 in base:GHC.Err
  undefined, called at Lazy2D_2.hs:41:8 in main:Main

因为,在计算walk22_head时,甚至walk21_head计算的最开始都取决于walk12_head引发的状态副作用walk11_tail

您最初的 walk 定义与这些模型的行为方式相同:

> head $ evalState (head $ walk 1 2) 0
1
> head $ evalState (head $ walk 2 1) 0
1
> head $ evalState (head (walk 1 2) >> head (walk 2 1)) 0
*** Exception: Prelude.undefined
CallStack (from HasCallStack):
  error, called at libraries/base/GHC/Err.hs:78:14 in base:GHC.Err
  undefined, called at Lazy2D_0.hs:15:49 in main:Main
> head $ evalState (head (walk 2 2)) 0
*** Exception: Prelude.undefined
CallStack (from HasCallStack):
  error, called at libraries/base/GHC/Err.hs:78:14 in base:GHC.Err
  undefined, called at Lazy2D_0.hs:15:49 in main:Main

很难说如何解决这个问题。您的玩具示例非常适合说明问题,但不清楚如何在 "real" 问题中使用状态以及 head $ walk 2 1 是否真的对 sequence 具有状态依赖性walk 1 1head $ walk 1 2.

引发的操作

是正确的:虽然在每个方向上迈出一步都很好(尝试 walkx < 2y < 2)的组合liftM2中的隐式>>=a中的sequence以及b中的状态依赖使得b依赖于所有a 的副作用。正如他还指出的那样,可行的解决方案取决于实际需要的依赖项。

我将针对我的特定情况分享一个解决方案:每个 walk 调用至少取决于调用者的状态,也许还有一些其他状态,基于网格的预序遍历和st 中的备选方案。此外,正如问题所暗示的那样,我想在 st 中测试任何不需要的替代方案之前尝试做出完整的结果。这在视觉上解释起来有点困难,但这是我能做的最好的事情:左边显示了每个坐标处 st 备选方案的可变数量(这是我在实际用例中所拥有的),右边显示了一个[相当混乱] 状态所需依赖顺序的映射:我们看到它在 3D DFS 中首先遍历 x-y,其中 "x" 作为深度(最快轴),"y" 作为宽度(中轴),然后最后选择最慢的轴(以带空心圆的虚线显示)。

最初实现的中心问题来自状态转换的排序列表,以适应非递归 return 类型。让我们用 monad 参数中递归的类型来替换列表类型,这样调用者可以更好地控制依赖顺序:

data ML m a = MCons a (MML m a) | MNil -- recursive monadic list
newtype MML m a = MML (m (ML m a)) -- base case wrapper

[1, 2]的例子:

MCons 1 (MML (return (MCons 2 (MML (return MNil)))))

Functor 和 Monoid 行为经常被使用,所以这里是相关的实现:

instance Functor m => Functor (ML m) where
  fmap f (MCons a m) = MCons (f a) (MML $ (fmap f) <$> coerce m)
  fmap _ MNil = MNil

instance Monad m => Semigroup (MML m a) where
  (MML l) <> (MML r) = MML $ l >>= mapper where
    mapper (MCons la lm) = return $ MCons la (lm <> (MML r))
    mapper MNil = r

instance Monad m => Monoid (MML m a) where
  mempty = MML (pure MNil)

有两个关键操作:在两个不同的轴上合并步骤,以及在同一坐标上合并来自不同备选方案的列表。分别为:

  1. 根据图表,我们希望首先从 x 步骤获得单个完整结果,然后从 y 步骤获得完整结果。每个步骤 return 都是来自内部坐标的可行备选方案的所有组合的结果列表,因此我们对两个列表进行笛卡尔积,也偏向一个方向(在本例中 y 最快)。首先,我们定义一个 "concatenation",它在裸列表 ML:

    的末尾应用基本案例包装器 MML
    nest :: Functor m => MML m a -> ML m a -> ML m a
    nest ma (MCons a mb) = MCons a (MML $ nest ma <$> coerce mb)
    

    然后笛卡尔积:

    prodML :: Monad m => (a -> a -> a) -> ML m a -> ML m a -> ML m a
    prodML f x (MCons ya ym) = (MML $ prodML f x <$> coerce ym) `nest` ((f ya) <$> x)
    prodML _ MNil _ = MNil
    
  2. 我们想将来自不同备选方案的列表粉碎成一个列表,我们不关心这会引入备选方案之间的依赖关系。这是我们从 Monoid 实例中使用 mconcat 的地方。

总而言之,它看起来像这样:

walk :: Int -> Int -> MML (State Int) Int
-- base cases
walk _ 0 = MML $ return $ MCons 1 (MML $ return MNil)
walk 0 _ = walk 0 0

walk x y =
  let st :: [State Int Int]
      st = [StateT (\s -> Identity (s, s + 1)), undefined]
      xstep = coerce $ walk (x-1) y
      ystep = coerce $ walk x (y-1)
     -- point 2: smash lists with mconcat
  in mconcat $ map (\mz -> MML $ do
      z <- mz
                              -- point 1: product over results
      liftM2 ((fmap (z+) .) . prodML (+)) xstep ystep
    ) st

headML (MCons a _) = a
headML _ = undefined

main :: IO ()
main = putStrLn $ show $ headML $ fst $ (`runState` 0) $ (\(MML m) -> m) $ walk 2 2

请注意结果已随语义发生变化。这对我来说无关紧要,因为我的目标只需要从状态中提取随机数,并且可以通过将列表元素正确引导到最终结果中来控制所需的任何依赖顺序。

(我还要警告,如果没有记忆或注意严格性,这个实现对于大的 x 和 y 是非常低效的。)