在 haskell 中记忆多维递归解决方案

Memoize multi-dimensional recursive solutions in haskell

我正在解决 haskell 中的一个递归问题,尽管我可以得到解决方案 我想缓存子问题的输出,因为有重叠子问题 属性.

问题是,给定一个维度为 n*m 的网格和一个整数 k,从 (1, 1) 到达网格 (n, m) 有多少种方式不超过 k 个方向改变?

这是没有记忆的代码

paths :: Int -> Int -> Int -> Int -> Int -> Int -> Integer
paths i j n m k dir
    | i > n || j > m || k < 0 = 0
    | i == n && j == m = 1
    | dir == 0 = paths (i+1) j n m k 1 + paths i (j+1) n m k 2        -- is in grid (1,1)
    | dir == 1 = paths (i+1) j n m k 1 + paths i (j+1) n m (k-1) 2    -- down was the direction took to reach here
    | dir == 2 = paths (i+1) j n m (k-1) 1 + paths i (j+1) n m k 2    -- right was the direction took to reach here 
    | otherwise = -1

这里的因变量是ijkdir。在 C++/Java 等语言中,可以使用 4 维 DP 数组(dp[n][m][k][3],在 Haskell 中,我找不到实现它的方法。

在 Haskell 中,这些事情确实不是最微不足道的事情。您真的希望进行一些就地突变以节省内存和时间,所以我认为没有比装备可怕的 ST monad 更好的方法了。

这可以在各种数据结构、数组、向量上完成,repa tensors. I chose HashTable from hashtables因为它使用起来最简单,而且性能足以在我的示例中发挥作用。


首先介绍一下:

{-# LANGUAGE Rank2Types #-}
module Solution where

import Control.Monad.ST
import Control.Monad
import Data.HashTable.ST.Basic as HT

Rank2Types 在处理 ST 时很有用,因为幻像类型。我选择了哈希表的 Basic 变体,因为作者声称它的查找速度最快 --- 我们将进行大量查找。

建议为地图使用类型别名,所以我们开始:

type Mem s = HT.HashTable s (Int, Int, Int, Int) Integer

ST-free 入口点只是为了创建地图并调用我们的怪物:

runpaths :: Int -> Int -> Int -> Int -> Int -> Int -> Integer
runpaths i j n m k dir = runST $ do
  mem <- HT.new
  paths mem i j n m k dir

这里是paths的记忆计算。我们只是尝试在地图中搜索结果,如果不存在,我们将其保存并 return:

mempaths mem i j n m k dir = do
  res <- HT.lookup mem (i, j, k, dir)
  case res of
    Just x -> return x
    Nothing -> do
      x <- paths mem i j n m k dir
      HT.insert mem (i, j, k, dir) x
      return x

这里是算法的大脑。它只是一个使用记忆调用代替普通递归的单子动作:

paths mem i j n m k dir
    | i > n || j > m || k < 0 = return 0
    | i == n && j == m = return 1
    | dir == 0 = do
        x1 <- mempaths mem (i+1) j n m k 1
        x2 <- mempaths mem i (j+1) n m k 2        -- is in grid (1,1)
        return $ x1 + x2
    | dir == 1 = do 
        x1 <- mempaths mem (i+1) j n m k 1
        x2 <- mempaths mem i (j+1) n m (k-1) 2    -- down was the direction took to reach here
        return $ x1 + x2
    | dir == 2 = do
        x1 <- mempaths mem (i+1) j n m (k-1) 1 
        x2 <- mempaths mem i (j+1) n m k 2    -- right was the direction took to reach here 
        return $ x1 + x2
    | otherwise = return (-1)

“喜结连理”是一种让 GHC 运行时为您记住结果的众所周知的技术,如果您提前知道您将需要查找的所有值。这个想法是把你的递归函数变成一个自引用的数据结构,然后简单地查找你真正关心的值。为此,我选择使用 Array,但 Map 也可以。在任何一种情况下,您使用的数组或映射都必须是 lazy/non-strict,因为我们将向其中插入值,直到整个数组被填满我们才准备好计算。

import Data.Array (array, bounds, inRange, (!))

paths :: Int -> Int -> Int -> Integer
paths m n k = go (1, 1, k, 0)
  where go (i, j, k, dir)
          | i == m && j == n = 1
          | dir == 1 = get (i+1, j, k, 1) + get (i, j+1, k-1, 2)    -- down was the direction took to reach here
          | dir == 2 = get (i+1, j, k-1, 1) + get (i, j+1, k, 2)    -- right was the direction took to reach here
          | otherwise = get (i+1, j, k, 1) + get (i, j+1, k, 2)     -- is in grid (1,1)
        a = array ((1, 1, 0, 1), (m, n, k, 2))
            [(c, go c) | c <- (,,,) <$> [1..m] <*> [1..n] <*> [0..k] <*> [1..2]]
        get x | inRange (bounds a) x = a ! x
              | otherwise = 0

我稍微简化了你的API:

  • mn 参数不会随每次迭代而改变,因此它们不应成为递归调用的一部分
  • 客户端不必告诉您 ijdir 的开头,因此它们已从函数签名中删除并隐式地从分别为 1、1 和 0
  • 我还交换了 mn 的顺序,因为先取一个 n 参数很奇怪。这让我有点头疼,因为我有一段时间没有注意到我也需要改变基本情况!

然后,正如我之前所说,我们的想法是用我们需要进行的所有递归调用填充数组:这就是 array 调用。请注意 array 中的单元格是通过调用 go 初始化的,这(基本情况除外!)涉及调用 get,这涉及查找数组中的元素。这样,a 就是自指或递归的。但是我们不必决定以什么顺序查找东西,或者以什么顺序插入它们:我们足够懒惰,GHC 根据需要计算数组元素。

我也有点厚脸皮,只在 dir=1dir=2 的数组中制作 space,而不是 dir=0。我逃脱了这个,因为 dir=0 只发生在第一次调用时,我可以直接调用 go,绕过 get 中的边界检查。如果传递小于 1 的 mn,或者小于零的 k,这个技巧确实意味着您将收到运行时错误。如果你需要处理这种情况,你可以为 paths 本身添加一个守卫。

当然,它确实有效:

> paths 3 3 2
4

您可以做的另一件事是为您的方向使用真实的数据类型,而不是 Int:

import Data.Array (Ix, array, bounds, inRange, (!))
import Prelude hiding (Right)

data Direction = Neutral | Down | Right deriving (Eq, Ord, Ix)

paths :: Int -> Int -> Int -> Integer
paths m n k = go (1, 1, k, Neutral)
  where go (i, j, k, dir)
          | i == m && j == n = 1
          | otherwise = case dir of
            Neutral -> get (i+1, j, k, Down) + get (i, j+1, k, Right)
            Down -> get (i+1, j, k, Down) + get (i, j+1, k-1, Right)
            Right -> get (i+1, j, k-1, Down) + get (i, j+1, k, Right)
        a = array ((1, 1, 0, Down), (m, n, k, Right))
            [(c, go c) | c <- (,,,) <$> [1..m] <*> [1..n] <*> [0..k] <*> [Down, Right]]
        get x | inRange (bounds a) x = a ! x
              | otherwise = 0

(I 和 J 可能比 Down 和 Right 更好,我不知道那个更容易记住还是更难记住)。我认为这是 可能 的改进,因为类型现在有了更多的意义,而且你没有这个奇怪的 otherwise 子句来处理 dir=7 这样的事情应该是非法的。但它仍然有点不稳定,因为它依赖于枚举值的排序:如果我们将 Neutral 放在 DownRight 之间,它会中断。 (我尝试完全删除 Neutral 方向并为第一步添加更多特殊外壳,但这以其自身的方式变得丑陋)