在 haskell 中记忆多维递归解决方案
Memoize multi-dimensional recursive solutions in haskell
我正在解决 haskell 中的一个递归问题,尽管我可以得到解决方案 我想缓存子问题的输出,因为有重叠子问题 属性.
问题是,给定一个维度为 n*m
的网格和一个整数 k
,从 (1, 1) 到达网格 (n, m) 有多少种方式不超过 k 个方向改变?
这是没有记忆的代码
paths :: Int -> Int -> Int -> Int -> Int -> Int -> Integer
paths i j n m k dir
| i > n || j > m || k < 0 = 0
| i == n && j == m = 1
| dir == 0 = paths (i+1) j n m k 1 + paths i (j+1) n m k 2 -- is in grid (1,1)
| dir == 1 = paths (i+1) j n m k 1 + paths i (j+1) n m (k-1) 2 -- down was the direction took to reach here
| dir == 2 = paths (i+1) j n m (k-1) 1 + paths i (j+1) n m k 2 -- right was the direction took to reach here
| otherwise = -1
这里的因变量是i
、j
、k
、dir
。在 C++/Java 等语言中,可以使用 4 维 DP 数组(dp[n][m][k][3]
,在 Haskell 中,我找不到实现它的方法。
在 Haskell 中,这些事情确实不是最微不足道的事情。您真的希望进行一些就地突变以节省内存和时间,所以我认为没有比装备可怕的 ST
monad 更好的方法了。
这可以在各种数据结构、数组、向量上完成,repa tensors. I chose HashTable
from hashtables因为它使用起来最简单,而且性能足以在我的示例中发挥作用。
首先介绍一下:
{-# LANGUAGE Rank2Types #-}
module Solution where
import Control.Monad.ST
import Control.Monad
import Data.HashTable.ST.Basic as HT
Rank2Types
在处理 ST
时很有用,因为幻像类型。我选择了哈希表的 Basic
变体,因为作者声称它的查找速度最快 --- 我们将进行大量查找。
建议为地图使用类型别名,所以我们开始:
type Mem s = HT.HashTable s (Int, Int, Int, Int) Integer
ST-free 入口点只是为了创建地图并调用我们的怪物:
runpaths :: Int -> Int -> Int -> Int -> Int -> Int -> Integer
runpaths i j n m k dir = runST $ do
mem <- HT.new
paths mem i j n m k dir
这里是paths
的记忆计算。我们只是尝试在地图中搜索结果,如果不存在,我们将其保存并 return:
mempaths mem i j n m k dir = do
res <- HT.lookup mem (i, j, k, dir)
case res of
Just x -> return x
Nothing -> do
x <- paths mem i j n m k dir
HT.insert mem (i, j, k, dir) x
return x
这里是算法的大脑。它只是一个使用记忆调用代替普通递归的单子动作:
paths mem i j n m k dir
| i > n || j > m || k < 0 = return 0
| i == n && j == m = return 1
| dir == 0 = do
x1 <- mempaths mem (i+1) j n m k 1
x2 <- mempaths mem i (j+1) n m k 2 -- is in grid (1,1)
return $ x1 + x2
| dir == 1 = do
x1 <- mempaths mem (i+1) j n m k 1
x2 <- mempaths mem i (j+1) n m (k-1) 2 -- down was the direction took to reach here
return $ x1 + x2
| dir == 2 = do
x1 <- mempaths mem (i+1) j n m (k-1) 1
x2 <- mempaths mem i (j+1) n m k 2 -- right was the direction took to reach here
return $ x1 + x2
| otherwise = return (-1)
“喜结连理”是一种让 GHC 运行时为您记住结果的众所周知的技术,如果您提前知道您将需要查找的所有值。这个想法是把你的递归函数变成一个自引用的数据结构,然后简单地查找你真正关心的值。为此,我选择使用 Array,但 Map 也可以。在任何一种情况下,您使用的数组或映射都必须是 lazy/non-strict,因为我们将向其中插入值,直到整个数组被填满我们才准备好计算。
import Data.Array (array, bounds, inRange, (!))
paths :: Int -> Int -> Int -> Integer
paths m n k = go (1, 1, k, 0)
where go (i, j, k, dir)
| i == m && j == n = 1
| dir == 1 = get (i+1, j, k, 1) + get (i, j+1, k-1, 2) -- down was the direction took to reach here
| dir == 2 = get (i+1, j, k-1, 1) + get (i, j+1, k, 2) -- right was the direction took to reach here
| otherwise = get (i+1, j, k, 1) + get (i, j+1, k, 2) -- is in grid (1,1)
a = array ((1, 1, 0, 1), (m, n, k, 2))
[(c, go c) | c <- (,,,) <$> [1..m] <*> [1..n] <*> [0..k] <*> [1..2]]
get x | inRange (bounds a) x = a ! x
| otherwise = 0
我稍微简化了你的API:
m
和 n
参数不会随每次迭代而改变,因此它们不应成为递归调用的一部分
- 客户端不必告诉您
i
、j
和 dir
的开头,因此它们已从函数签名中删除并隐式地从分别为 1、1 和 0
- 我还交换了
m
和 n
的顺序,因为先取一个 n
参数很奇怪。这让我有点头疼,因为我有一段时间没有注意到我也需要改变基本情况!
然后,正如我之前所说,我们的想法是用我们需要进行的所有递归调用填充数组:这就是 array
调用。请注意 array
中的单元格是通过调用 go
初始化的,这(基本情况除外!)涉及调用 get
,这涉及查找数组中的元素。这样,a
就是自指或递归的。但是我们不必决定以什么顺序查找东西,或者以什么顺序插入它们:我们足够懒惰,GHC 根据需要计算数组元素。
我也有点厚脸皮,只在 dir=1
和 dir=2
的数组中制作 space,而不是 dir=0
。我逃脱了这个,因为 dir=0
只发生在第一次调用时,我可以直接调用 go
,绕过 get
中的边界检查。如果传递小于 1 的 m
或 n
,或者小于零的 k
,这个技巧确实意味着您将收到运行时错误。如果你需要处理这种情况,你可以为 paths
本身添加一个守卫。
当然,它确实有效:
> paths 3 3 2
4
您可以做的另一件事是为您的方向使用真实的数据类型,而不是 Int
:
import Data.Array (Ix, array, bounds, inRange, (!))
import Prelude hiding (Right)
data Direction = Neutral | Down | Right deriving (Eq, Ord, Ix)
paths :: Int -> Int -> Int -> Integer
paths m n k = go (1, 1, k, Neutral)
where go (i, j, k, dir)
| i == m && j == n = 1
| otherwise = case dir of
Neutral -> get (i+1, j, k, Down) + get (i, j+1, k, Right)
Down -> get (i+1, j, k, Down) + get (i, j+1, k-1, Right)
Right -> get (i+1, j, k-1, Down) + get (i, j+1, k, Right)
a = array ((1, 1, 0, Down), (m, n, k, Right))
[(c, go c) | c <- (,,,) <$> [1..m] <*> [1..n] <*> [0..k] <*> [Down, Right]]
get x | inRange (bounds a) x = a ! x
| otherwise = 0
(I 和 J 可能比 Down 和 Right 更好,我不知道那个更容易记住还是更难记住)。我认为这是 可能 的改进,因为类型现在有了更多的意义,而且你没有这个奇怪的 otherwise
子句来处理 dir=7
这样的事情应该是非法的。但它仍然有点不稳定,因为它依赖于枚举值的排序:如果我们将 Neutral
放在 Down
和 Right
之间,它会中断。 (我尝试完全删除 Neutral
方向并为第一步添加更多特殊外壳,但这以其自身的方式变得丑陋)
我正在解决 haskell 中的一个递归问题,尽管我可以得到解决方案 我想缓存子问题的输出,因为有重叠子问题 属性.
问题是,给定一个维度为 n*m
的网格和一个整数 k
,从 (1, 1) 到达网格 (n, m) 有多少种方式不超过 k 个方向改变?
这是没有记忆的代码
paths :: Int -> Int -> Int -> Int -> Int -> Int -> Integer
paths i j n m k dir
| i > n || j > m || k < 0 = 0
| i == n && j == m = 1
| dir == 0 = paths (i+1) j n m k 1 + paths i (j+1) n m k 2 -- is in grid (1,1)
| dir == 1 = paths (i+1) j n m k 1 + paths i (j+1) n m (k-1) 2 -- down was the direction took to reach here
| dir == 2 = paths (i+1) j n m (k-1) 1 + paths i (j+1) n m k 2 -- right was the direction took to reach here
| otherwise = -1
这里的因变量是i
、j
、k
、dir
。在 C++/Java 等语言中,可以使用 4 维 DP 数组(dp[n][m][k][3]
,在 Haskell 中,我找不到实现它的方法。
在 Haskell 中,这些事情确实不是最微不足道的事情。您真的希望进行一些就地突变以节省内存和时间,所以我认为没有比装备可怕的 ST
monad 更好的方法了。
这可以在各种数据结构、数组、向量上完成,repa tensors. I chose HashTable
from hashtables因为它使用起来最简单,而且性能足以在我的示例中发挥作用。
首先介绍一下:
{-# LANGUAGE Rank2Types #-}
module Solution where
import Control.Monad.ST
import Control.Monad
import Data.HashTable.ST.Basic as HT
Rank2Types
在处理 ST
时很有用,因为幻像类型。我选择了哈希表的 Basic
变体,因为作者声称它的查找速度最快 --- 我们将进行大量查找。
建议为地图使用类型别名,所以我们开始:
type Mem s = HT.HashTable s (Int, Int, Int, Int) Integer
ST-free 入口点只是为了创建地图并调用我们的怪物:
runpaths :: Int -> Int -> Int -> Int -> Int -> Int -> Integer
runpaths i j n m k dir = runST $ do
mem <- HT.new
paths mem i j n m k dir
这里是paths
的记忆计算。我们只是尝试在地图中搜索结果,如果不存在,我们将其保存并 return:
mempaths mem i j n m k dir = do
res <- HT.lookup mem (i, j, k, dir)
case res of
Just x -> return x
Nothing -> do
x <- paths mem i j n m k dir
HT.insert mem (i, j, k, dir) x
return x
这里是算法的大脑。它只是一个使用记忆调用代替普通递归的单子动作:
paths mem i j n m k dir
| i > n || j > m || k < 0 = return 0
| i == n && j == m = return 1
| dir == 0 = do
x1 <- mempaths mem (i+1) j n m k 1
x2 <- mempaths mem i (j+1) n m k 2 -- is in grid (1,1)
return $ x1 + x2
| dir == 1 = do
x1 <- mempaths mem (i+1) j n m k 1
x2 <- mempaths mem i (j+1) n m (k-1) 2 -- down was the direction took to reach here
return $ x1 + x2
| dir == 2 = do
x1 <- mempaths mem (i+1) j n m (k-1) 1
x2 <- mempaths mem i (j+1) n m k 2 -- right was the direction took to reach here
return $ x1 + x2
| otherwise = return (-1)
“喜结连理”是一种让 GHC 运行时为您记住结果的众所周知的技术,如果您提前知道您将需要查找的所有值。这个想法是把你的递归函数变成一个自引用的数据结构,然后简单地查找你真正关心的值。为此,我选择使用 Array,但 Map 也可以。在任何一种情况下,您使用的数组或映射都必须是 lazy/non-strict,因为我们将向其中插入值,直到整个数组被填满我们才准备好计算。
import Data.Array (array, bounds, inRange, (!))
paths :: Int -> Int -> Int -> Integer
paths m n k = go (1, 1, k, 0)
where go (i, j, k, dir)
| i == m && j == n = 1
| dir == 1 = get (i+1, j, k, 1) + get (i, j+1, k-1, 2) -- down was the direction took to reach here
| dir == 2 = get (i+1, j, k-1, 1) + get (i, j+1, k, 2) -- right was the direction took to reach here
| otherwise = get (i+1, j, k, 1) + get (i, j+1, k, 2) -- is in grid (1,1)
a = array ((1, 1, 0, 1), (m, n, k, 2))
[(c, go c) | c <- (,,,) <$> [1..m] <*> [1..n] <*> [0..k] <*> [1..2]]
get x | inRange (bounds a) x = a ! x
| otherwise = 0
我稍微简化了你的API:
m
和n
参数不会随每次迭代而改变,因此它们不应成为递归调用的一部分- 客户端不必告诉您
i
、j
和dir
的开头,因此它们已从函数签名中删除并隐式地从分别为 1、1 和 0 - 我还交换了
m
和n
的顺序,因为先取一个n
参数很奇怪。这让我有点头疼,因为我有一段时间没有注意到我也需要改变基本情况!
然后,正如我之前所说,我们的想法是用我们需要进行的所有递归调用填充数组:这就是 array
调用。请注意 array
中的单元格是通过调用 go
初始化的,这(基本情况除外!)涉及调用 get
,这涉及查找数组中的元素。这样,a
就是自指或递归的。但是我们不必决定以什么顺序查找东西,或者以什么顺序插入它们:我们足够懒惰,GHC 根据需要计算数组元素。
我也有点厚脸皮,只在 dir=1
和 dir=2
的数组中制作 space,而不是 dir=0
。我逃脱了这个,因为 dir=0
只发生在第一次调用时,我可以直接调用 go
,绕过 get
中的边界检查。如果传递小于 1 的 m
或 n
,或者小于零的 k
,这个技巧确实意味着您将收到运行时错误。如果你需要处理这种情况,你可以为 paths
本身添加一个守卫。
当然,它确实有效:
> paths 3 3 2
4
您可以做的另一件事是为您的方向使用真实的数据类型,而不是 Int
:
import Data.Array (Ix, array, bounds, inRange, (!))
import Prelude hiding (Right)
data Direction = Neutral | Down | Right deriving (Eq, Ord, Ix)
paths :: Int -> Int -> Int -> Integer
paths m n k = go (1, 1, k, Neutral)
where go (i, j, k, dir)
| i == m && j == n = 1
| otherwise = case dir of
Neutral -> get (i+1, j, k, Down) + get (i, j+1, k, Right)
Down -> get (i+1, j, k, Down) + get (i, j+1, k-1, Right)
Right -> get (i+1, j, k-1, Down) + get (i, j+1, k, Right)
a = array ((1, 1, 0, Down), (m, n, k, Right))
[(c, go c) | c <- (,,,) <$> [1..m] <*> [1..n] <*> [0..k] <*> [Down, Right]]
get x | inRange (bounds a) x = a ! x
| otherwise = 0
(I 和 J 可能比 Down 和 Right 更好,我不知道那个更容易记住还是更难记住)。我认为这是 可能 的改进,因为类型现在有了更多的意义,而且你没有这个奇怪的 otherwise
子句来处理 dir=7
这样的事情应该是非法的。但它仍然有点不稳定,因为它依赖于枚举值的排序:如果我们将 Neutral
放在 Down
和 Right
之间,它会中断。 (我尝试完全删除 Neutral
方向并为第一步添加更多特殊外壳,但这以其自身的方式变得丑陋)