记忆类型 [Integer] -> a 的函数

Memoizing a function of type [Integer] -> a

我的问题是:如何高效地记忆化一个开销很大的函数 f :: [Integer] -> a?它对所有有限(finite)整数列表都有定义,并且具有属性 f . sort = f。

我的典型用例是:给定一个整数列表 as,我需要对各种整数 a 获取 f (a:as) 的值,所以我想同时建立一个有向标记图,其顶点是由整数列表及其函数值组成的对。当且仅当 a:as = bs 时,存在一条从 (as, f as) 到 (bs, f bs)、标记为 a 的边。

借用 Edward Kmett 的精彩回答(brilliant answer),我直接照搬了以下代码:

{-# LANGUAGE BangPatterns #-}
-- | Infinite binary tree: left subtree, the value at this node, right subtree.
data Tree a = Tree (Tree a) a (Tree a)

-- | Applies the function to every node value; the shape is unchanged.
instance Functor Tree where
  fmap fn (Tree left here right) = Tree (fn <$> left) (fn here) (fn <$> right)

-- | Looks up the value stored at position @n@ of the tree.
--
-- Position 0 is the root.  For @n > 0@, the quotient and remainder of
-- @n - 1@ by 2 give the position inside the left (remainder 0) or the right
-- (remainder 1) subtree.
--
-- Negative positions do not name any node (descending on them would never
-- reach position 0), so they are rejected with an explicit error.
index :: Tree a -> Integer -> a
index (Tree l m r) n
  | n < 0 = error ("index: negative position " ++ show n)
  | n == 0 = m
  | otherwise = case (n - 1) `divMod` 2 of
      (q, 0) -> index l q
      (q, _) -> index r q -- the remainder can only be 1 here

-- | Tree in which the node at position @n@ (in the sense of 'index') holds
-- the number @n@ itself.
nats :: Tree Integer
nats = build 0 1
  where
    -- @here@ is the value of this node; @stride@ is the distance between
    -- the roots of its two subtrees.
    build !here !stride =
      let leftRoot = here + stride
          rightRoot = leftRoot + stride
          childStride = stride * 2
       in Tree (build leftRoot childStride) here (build rightRoot childStride)

并将他的想法应用到我的问题中

-- | Directed graph whose edges are labelled by Integers: a node holds a
-- value and a 'Tree' containing its successor nodes.
data Graph a = Graph a (Tree (Graph a))

-- | Maps the function over the value of every node.
instance Functor Graph where
  fmap fn (Graph here successors) = Graph (fn here) (fmap fn <$> successors)

-- | Follows the given edge labels, one after the other, starting at the
-- given node, and returns the value of the node that is reached.
walk :: Graph a -> [Integer] -> a
walk (Graph here successors) labels = case labels of
  [] -> here
  x : rest -> walk (index successors x) rest

-- | Graph of all finite integer sequences: the root holds the empty list,
-- and the successor stored for label @n@ is the whole graph again with @n@
-- prepended to the list of every node.
intSeq :: Graph [Integer]
intSeq = Graph [] (prepended <$> nats)
  where
    prepended n = (n :) <$> intSeq

-- Pair whose two components are both strict fields (note the bangs).
-- could be replaced by Data.Strict.Pair
data StrictPair a b = StrictPair !a !b
  deriving Show

-- | Open-recursive variant of @sum@, modified according to Edward's idea
-- (the real function is more complicated).  The result for the tail of the
-- list is obtained through the supplied function instead of a direct
-- recursive call.  The second component of the result is the input list.
g :: ([Integer] -> StrictPair Integer [Integer]) -> [Integer] -> StrictPair Integer [Integer]
g _ [] = StrictPair 0 []
g recur (a:as) = StrictPair (a + tailSum) (a:as)
  where StrictPair tailSum _ = recur as

-- Memo structure: g is mapped over every node of intSeq, with the recursive
-- calls of g going through g_m, i.e. back into this same graph.  g_graph and
-- g_m are mutually recursive.
g_graph :: Graph (StrictPair Integer [Integer])
g_graph = fmap (g g_m) intSeq

-- Lookup function: walks g_graph along the argument list.
g_m :: [Integer] -> StrictPair Integer [Integer]
g_m = walk g_graph

这可以正常工作,但由于函数 f 与出现的整数的顺序无关(只与它们各自出现的次数有关),对于所有仅在顺序上不同的整数列表,图中应该只有一个顶点。

我该如何实现?

如何定义 g_m' = g_m . sort,即在调用记忆函数之前先对输入列表进行排序?

我觉得这是你能做的最好的事情,因为如果你想让你的记忆图只包含排序的路径,那么有人将不得不在构建路径之前查看列表的所有元素。

根据您的输入列表的外观,以减少树分支的方式转换它们可能会有所帮助。例如,您可以尝试排序和取差:

original input list:   [8,3,14,8,5]
sorted:                [3,3,8,8,14]
diffed:                [3,0,5,0,6] -- use this as the key

转换是一个双射,树分支较少,因为涉及的数字较小。

您可以使用稍微不同的方法。 有一个技巧可以证明可数集的有限积是可数的:

我们可以通过 product . zipWith (^) primes 将序列 [a1, ..., an] 映射到自然数(Nat):2 ^ a1 * 3 ^ a2 * 5 ^ a3 * ... * primen ^ an。

为避免末尾为零的序列出现问题,我们可以增加最后一个索引。

由于序列是有序的,我们可以利用 user5402 提到的属性。

使用树的好处是可以增加分支以加快遍历。 OTOH prime 技巧可能会使索引变得非常大,但希望一些树路径将未被探索(保持为 thunks)。

{-# LANGUAGE BangPatterns #-}

-- Modified from Kmett's answer:
-- | Infinite tree with a value at every node and four subtrees.
data Tree a = Tree a (Tree a) (Tree a) (Tree a) (Tree a)

-- | Applies the function to every node value; the shape is unchanged.
instance Functor Tree where
  fmap fn (Tree here t0 t1 t2 t3) =
    Tree (fn here) (fn <$> t0) (fn <$> t1) (fn <$> t2) (fn <$> t3)

-- | Looks up the value stored at position @n@ of the four-way tree.
--
-- Position 0 is the root.  For @n > 0@, the remainder of @n - 1@ by 4 picks
-- one of the four subtrees and the quotient is the position inside it.
--
-- Negative positions do not name any node (descending on them would never
-- reach position 0), so they are rejected with an explicit error.
index :: Tree a -> Integer -> a
index (Tree x a b c d) n
  | n < 0 = error ("index: negative position " ++ show n)
  | n == 0 = x
  | otherwise = case (n - 1) `divMod` 4 of
      (q, 0) -> index a q
      (q, 1) -> index b q
      (q, 2) -> index c q
      (q, _) -> index d q -- the remainder can only be 3 here

-- | Four-way tree whose nodes hold the natural numbers; built so that the
-- roots of a node's four subtrees are @stride@ apart and each level
-- multiplies the stride by 4.
nats :: Tree Integer
nats = build 0 1
  where
    build !here !stride = Tree here (child 1) (child 2) (child 3) (child 4)
      where
        -- Root value of the k-th subtree (k = 1..4) and the stride below it.
        child k = build (here + k * stride) (stride * 4)

-- | Lazily lists the values of the tree in position order: 0, 1, 2, ...
toList :: Tree a -> [a]
toList tree = [index tree i | i <- [0 ..]]

-- Primes -- https://www.haskell.org/haskellwiki/Prime_numbers
-- Generation and factorisation could be done much better

-- | Difference of two ascending lists: the first list without the elements
-- that also occur in the second (one removal per matching element).  The
-- merge-style walk relies on both inputs being sorted ascending.
minus :: Ord a => [a] -> [a] -> [a]
minus (x:xs) (y:ys) = case compare x y of
  LT -> x : minus xs (y:ys)
  EQ -> minus xs ys
  GT -> minus (x:xs) ys
minus xs _ = xs

-- | The ascending list of primes.  With @p@ the next sieving prime and
-- @q = p * p@, the candidates below @q@ are emitted unchanged; the multiples
-- of @p@ from @q@ onwards are then removed from the remaining candidates.
primes :: [Integer]
primes = 2 : sieve [3..] primes
  where
    sieve xs (p:ps) | q <- p*p , (h,t) <- span (< q) xs =
                   h ++ sieve (t `minus` [q, q+p..]) ps
    -- Not expected to be reached: 'primes' always has a next element.
    sieve xs [] = xs

-- | Increments the last element of the list; the empty list stays empty.
addToLast :: [Integer] -> [Integer]
addToLast list = case list of
  [] -> []
  [final] -> [final + 1]
  x : rest -> x : addToLast rest

-- | Decrements the last element of the list; the empty list stays empty.
subFromLast :: [Integer] -> [Integer]
subFromLast list = case list of
  [] -> []
  [final] -> [final - 1]
  x : rest -> x : subFromLast rest

-- | QuickCheck property: 'subFromLast' undoes 'addToLast' on lists of
-- non-negative integers.
addSubProp :: [NonNegative Integer] -> Property
addSubProp wrapped =
  let plain = getNonNegative <$> wrapped
   in plain === subFromLast (addToLast plain)

-- | Trick from user5402's answer: replaces every element by its difference
-- from the preceding element (the first element is compared against 0).
toDiffList :: [Integer] -> [Integer]
toDiffList xs = zipWith (-) xs (0 : xs)

-- | Rebuilds a list from its differences: every output element is the
-- running sum of the inputs up to and including that position.
fromDiffList :: [Integer] -> [Integer]
fromDiffList = scanl1 (+)

-- | QuickCheck property: 'fromDiffList' undoes 'toDiffList'.
diffProp :: [Integer] -> Property
diffProp xs = xs === roundTrip xs
  where
    roundTrip = fromDiffList . toDiffList

-- | Encodes a list as one Integer: after 'addToLast', the i-th entry is used
-- as the exponent of the i-th prime and all prime powers are multiplied.
listToInteger :: [Integer] -> Integer
listToInteger xs = product (zipWith (^) primes (addToLast xs))

-- Decoding counterpart of listToInteger: collects, prime by prime, how many
-- times each prime divides the number, then applies subFromLast to the
-- resulting exponent list.  An input of 0 yields the empty list.
integerToList :: Integer -> [Integer]
integerToList = subFromLast . impl primes 0
  -- impl ps k n: ps are the primes still to be tried, k is the exponent
  -- counted so far for the current prime, n is what remains of the number.
  where impl _      _ 0 = []
        -- Remainder 1 means everything is divided out: emit the exponent in
        -- progress only if it is non-zero.
        impl _      0 1 = []
        impl _      k 1 = [k]
        -- Otherwise keep dividing by the current prime while it divides n;
        -- once it no longer does, emit its exponent and move to the next prime.
        impl (p:ps) k n = case n `divMod` p of
                            (n', 0) -> impl (p:ps) (k + 1) n'
                            (_,  _) -> k : impl ps 0 n

-- | QuickCheck property: 'integerToList' undoes 'listToInteger' on lists of
-- non-negative integers.
listProp :: [NonNegative Integer] -> Property
listProp wrapped =
  let plain = getNonNegative <$> wrapped
   in plain === integerToList (listToInteger plain)

-- | Tree position of a list: take differences, then prime-encode.
toIndex :: [Integer] -> Integer
toIndex xs = listToInteger (toDiffList xs)

-- | List for a tree position: prime-decode, then sum the differences back up.
fromIndex :: Integer -> [Integer]
fromIndex i = fromDiffList (integerToList i)

-- | Round trip through 'toIndex' and 'fromIndex' on arbitrary, unsorted
-- input.  This property fails:
--
-- [1,0] /= [0]
-- Decreasing sequence!
doesntHold :: [NonNegative Integer] -> Property
doesntHold wrapped =
  let plain = getNonNegative <$> wrapped
   in plain === fromIndex (toIndex plain)

-- | The same round trip restricted to sorted input.
holds :: [NonNegative Integer] -> Property
holds wrapped =
  let sorted = sort (getNonNegative <$> wrapped)
   in sorted === fromIndex (toIndex sorted)

-- | Open-recursive example function: sorts the input and, for a non-empty
-- list, adds to the smallest element the results of the supplied function
-- on every suffix of the remaining (sorted) elements, the empty suffix
-- included.  The empty list gives 0.
g :: ([Integer] -> Integer) -> [Integer] -> Integer
g mg xs = case sort xs of
  [] -> 0
  smallest : rest -> smallest + sum [mg suffix | suffix <- tails rest]

-- | Memo table: the node at position @i@ holds 'g' applied to the list
-- decoded from @i@, with the recursive calls of 'g' going back through the
-- table via 'faster_g''.
g_tree :: Tree Integer
g_tree = fmap (g faster_g' . fromIndex) nats

-- | Table lookup.  Intended for sorted arguments: the index round trip is
-- only stated to hold for sorted lists (see 'holds' / 'doesntHold').
faster_g' :: [Integer] -> Integer
faster_g' = index g_tree . toIndex

-- | Entry point for arbitrary lists: sorts first, then looks up.
faster_g :: [Integer] -> Integer
faster_g = faster_g' . sort

在我的机器上 fix g [1..22] 感觉很慢,而 faster_g [1..40] 仍然非常快。


补充:如果我们有一个有界集(索引为 0..n-1),就可以将序列编码为:a0 * n^0 + a1 * n^1 + ...

我们可以将任何 Integer 编码为二进制列表,例如 11 编码为 [1, 1, 0, 1](最低位优先)。然后,如果我们用 2 来分隔列表中的各个整数,就会得到一个由有界值(0、1、2)组成的序列。

作为奖励,我们可以将 0、1、2 数字的序列 压缩 为二进制,例如使用霍夫曼编码,因为 2 比 0 或 1 少得多。但这可能有点矫枉过正。

有了这个技巧,索引会变得更小,space 可能会更好地打包。

{-# LANGUAGE BangPatterns #-}

-- From Kmett's answer:
import Data.Function (fix)
import Data.List (sort, tails)
import Data.List.Split (splitOn)
import Test.QuickCheck

{-- Tree definition as before --}

-- | A base-3 digit: wraps an Integer meant to be 0, 1 or 2 (the range is
-- not enforced by the type).
newtype N3 = N3 { unN3 :: Integer }
  deriving (Eq, Show)

-- | Generates one of the three digits.
instance Arbitrary N3 where
  arbitrary = elements [N3 0, N3 1, N3 2]

-- Integer <-> N3

-- | The powers of three in ascending order: 1, 3, 9, ...
coeffs3 :: [Integer]
coeffs3 = iterate (* 3) 1

-- | Reads a digit list as a base-3 number, least significant digit first.
listToInteger :: [N3] -> Integer
listToInteger digits = sum [weight * d | (weight, N3 d) <- zip coeffs3 digits]

-- | Base-3 digits of a number, least significant digit first; 0 has no
-- digits at all.
listFromInteger :: Integer -> [N3]
listFromInteger n
  | n == 0 = []
  | otherwise =
      let (rest, digit) = n `divMod` 3
       in N3 digit : listFromInteger rest

-- | QuickCheck property: a digit list that does not end in the digit 0
-- survives the round trip through 'listToInteger' and 'listFromInteger'.
listProp :: [N3] -> Property
listProp digits = noTrailingZero ==> digits === listFromInteger (listToInteger digits)
  where
    noTrailingZero = null digits || last digits /= N3 0

-- Integer <-> N2

-- A binary digit: wraps an Integer meant to be 0 or 1 (the range is not
-- enforced by the type).
newtype N2 = N2 { unN2 :: Integer }
  deriving (Eq, Show)

-- | The powers of two in ascending order: 1, 2, 4, ...
coeffs2 :: [Integer]
coeffs2 = iterate (* 2) 1

-- | Binary digits of a number, least significant bit first; 0 has no
-- digits at all.
integerToBin :: Integer -> [N2]
integerToBin n
  | n == 0 = []
  | otherwise =
      let (rest, bit) = n `divMod` 2
       in N2 bit : integerToBin rest

-- | Reads a bit list as a binary number, least significant bit first.
integerFromBin :: [N2] -> Integer
integerFromBin bits = sum [weight * b | (weight, N2 b) <- zip coeffs2 bits]

-- | QuickCheck property: 'integerFromBin' undoes 'integerToBin' on
-- non-negative integers.
binProp :: NonNegative Integer -> Property
binProp wrapped = n === integerFromBin (integerToBin n)
  where
    n = getNonNegative wrapped

-- | Unsafe: reinterprets a base-3 digit as a binary digit without checking
-- that its value is 0 or 1.
n3ton2 :: N3 -> N2
n3ton2 (N3 d) = N2 d

-- | Reinterprets a binary digit as a base-3 digit.
n2ton3 :: N2 -> N3
n2ton3 (N2 d) = N3 d

-- [Integer] <-> [N3]

-- | Writes every integer as its binary digits (via 'integerToBin') and
-- appends the digit 2 after each one as a separator.
integerListToN3List :: [Integer] -> [N3]
integerListToN3List = concatMap encodeOne
  where
    encodeOne n = map n2ton3 (integerToBin n) ++ [N3 2]

-- | Splits the digits at every separator digit 2, reads each chunk as a
-- binary number and drops the final chunk (whatever follows the last
-- separator).
integerListFromN3List :: [N3] -> [Integer]
integerListFromN3List digits = init (map decodeOne (splitOn [N3 2] digits))
  where
    decodeOne = integerFromBin . map n3ton2

-- | QuickCheck property: 'integerListFromN3List' undoes
-- 'integerListToN3List' on lists of non-negative integers.
n3ListProp :: [NonNegative Integer] -> Property
n3ListProp wrapped =
  let plain = getNonNegative <$> wrapped
   in plain === integerListFromN3List (integerListToN3List plain)

-- Integer <-> Sorted Integer

-- | Trick from user5402's answer: replaces every element by its difference
-- from the preceding element (the first element is compared against 0).
toDiffList :: [Integer] -> [Integer]
toDiffList xs = zipWith (-) xs (0 : xs)

-- | Rebuilds a list from its differences: every output element is the
-- running sum of the inputs up to and including that position.
fromDiffList :: [Integer] -> [Integer]
fromDiffList = scanl1 (+)

-- | QuickCheck property: 'fromDiffList' undoes 'toDiffList'.
diffProp :: [Integer] -> Property
diffProp xs = xs === roundTrip xs
  where
    roundTrip = fromDiffList . toDiffList

---

-- | Tree position of a list: take differences, write them as base-3 digits
-- with separators, and read those digits as one number.
toIndex :: [Integer] -> Integer
toIndex xs = listToInteger (integerListToN3List (toDiffList xs))

-- | List for a tree position: the three encoding steps undone in reverse
-- order.
fromIndex :: Integer -> [Integer]
fromIndex i = fromDiffList (integerListFromN3List (listFromInteger i))

-- | Round trip through 'toIndex' and 'fromIndex' on arbitrary, unsorted
-- input.  This property does not hold:
--
-- [1,0] /= [0]
-- Decreasing sequence! doesn't terminate in this case
doesntHold :: [NonNegative Integer] -> Property
doesntHold wrapped =
  let plain = getNonNegative <$> wrapped
   in plain === fromIndex (toIndex plain)

-- | The same round trip restricted to sorted input.
holds :: [NonNegative Integer] -> Property
holds wrapped =
  let sorted = sort (getNonNegative <$> wrapped)
   in sorted === fromIndex (toIndex sorted)

-- | Open-recursive example function: sorts the input and, for a non-empty
-- list, adds to the smallest element the results of the supplied function
-- on every suffix of the remaining (sorted) elements, the empty suffix
-- included.  The empty list gives 0.
g :: ([Integer] -> Integer) -> [Integer] -> Integer
g mg xs = case sort xs of
  [] -> 0
  smallest : rest -> smallest + sum [mg suffix | suffix <- tails rest]

-- | Memo table: the node at position @i@ holds 'g' applied to the list
-- decoded from @i@, with the recursive calls of 'g' going back through the
-- table via 'faster_g''.
g_tree :: Tree Integer
g_tree = fmap (g faster_g' . fromIndex) nats

-- | Table lookup.  Intended for sorted arguments: the index round trip is
-- only stated to hold for sorted lists (see 'holds' / 'doesntHold').
faster_g' :: [Integer] -> Integer
faster_g' = index g_tree . toIndex

-- | Entry point for arbitrary lists: sorts first, then looks up.
faster_g :: [Integer] -> Integer
faster_g = faster_g' . sort

第二次补充:

我为我的 g 快速对图和二进制序列方法进行了基准测试:

-- | Benchmark driver: takes the upper bound @n@ from the first command-line
-- argument and prints @faster_g [100, 110 .. n]@.
--
-- A missing or non-numeric argument is reported as an IO error with a
-- readable message instead of failing inside 'head' or 'read'.
main :: IO ()
main = do
  args <- getArgs
  case args of
    (arg : _) | Just n <- parseBound arg -> print $ faster_g [100, 110 .. n]
    _ -> ioError (userError "expected an integer upper bound as the first argument")
  where
    -- Total counterpart of 'read': succeeds exactly when there is a single
    -- parse that consumes the whole string (surrounding whitespace allowed).
    parseBound :: String -> Maybe Integer
    parseBound s = case [x | (x, rest) <- reads s, ("", "") <- lex rest] of
      [x] -> Just x
      _ -> Nothing

结果是:

% time ./IntegerMemo 1000
1225560638892526472150132981770
./IntegerMemo 1000  0.19s user 0.01s system 98% cpu 0.200 total
% time ./IntegerMemo 2000
3122858113354873680008305238045814042010921833620857170165770
./IntegerMemo 2000  1.83s user 0.05s system 99% cpu 1.888 total
% time ./IntegerMemo 2500
4399449191298176980662410776849867104410434903220291205722799441218623242250
./IntegerMemo 2500  3.74s user 0.09s system 99% cpu 3.852 total
% time ./IntegerMemo 3000    
5947985907461048240178371687835977247601455563536278700587949163642187584269899171375349770
./IntegerMemo 3000  6.66s user 0.13s system 99% cpu 6.830 total

% time ./IntegerMemoGrap 1000 
1225560638892526472150132981770
./IntegerMemoGrap 1000  0.10s user 0.01s system 97% cpu 0.113 total
% time ./IntegerMemoGrap 2000
3122858113354873680008305238045814042010921833620857170165770
./IntegerMemoGrap 2000  0.97s user 0.04s system 98% cpu 1.028 total
% time ./IntegerMemoGrap 2500
4399449191298176980662410776849867104410434903220291205722799441218623242250
./IntegerMemoGrap 2500  2.11s user 0.08s system 99% cpu 2.202 total
% time ./IntegerMemoGrap 3000 
5947985907461048240178371687835977247601455563536278700587949163642187584269899171375349770
./IntegerMemoGrap 3000  3.33s user 0.09s system 99% cpu 3.452 total

看起来图版本大约快一个 2 倍的常数因子,但两者似乎具有相同的时间复杂度 :)

看来我的问题是通过简单地用单调版本替换 g_graph 定义中的 intSeq 来解决的:

-- | Variant of 'intSeq' in which every vertex whose list is not sorted is
-- replaced by the vertex of the corresponding sorted list, so that lists
-- differing only in order share one vertex.
monoIntSeq :: Graph [Integer]
monoIntSeq = f intSeq
  where
    f (Graph as t)
      -- Already sorted: keep the vertex and treat its successors the same way.
      | as == sorted = Graph as $ fmap f t
      -- Not sorted: redirect to the vertex reached along the sorted list.
      | otherwise = fetch monoIntSeq sorted
      where
        sorted = sort as

-- | Follows the given labels from a vertex and returns the whole subgraph
-- (not just the value) found at the end.
fetch :: Graph a -> [Integer] -> Graph a
fetch g labels = case labels of
  [] -> g
  x : rest -> case g of
    Graph _ successors -> fetch (index successors x) rest

-- | Memo graph built over 'monoIntSeq': 'g' is applied to the list of every
-- vertex, with its recursive calls going through 'g_m'.
g_graph :: Graph (StrictPair Integer [Integer])
g_graph = g g_m <$> monoIntSeq

非常感谢大家(尤其是 user5402 和 Oleg)的帮助!


编辑: 我的典型用例仍然存在内存消耗过高的问题,可以通过以下路径进行描述:

-- | Example path: the positive integers 1, 2, 3, ..., each divided by 6, 3
-- or 2 when it is a multiple of that number (the largest applicable divisor
-- wins, which is exactly @gcd n 6@), and left unchanged otherwise.
p :: [Integer]
p = [n `div` gcd n 6 | n <- [1 ..]]

一个小的改进是像这样直接定义单调整数序列:

-- | Extracts the subgraph reached by following the given labels, consuming
-- them with a strict left fold that starts at the given vertex.
-- NOTE(review): the earlier comment described this as "right to left"; the
-- fold itself takes the head of the list first -- confirm the intended reading.
fetch :: Graph a -> [Integer] -> Graph a
fetch start labels = foldl' descend start labels
  where
    descend (Graph _ successors) label = index successors label

-- | Value of the vertex that 'fetch' reaches from the given vertex along
-- the given labels.
walk :: Graph a -> [Integer] -> a
walk start labels = case fetch start labels of
  Graph here _ -> here

-- | Graph of all monotone falling integer sequences, defined directly.
-- Prepending @n@ to a vertex keeps the vertex when its list is empty or @n@
-- is at least its head; otherwise the vertex is replaced by the one fetched
-- along the list with @n@ inserted at its place in descending order.
monoIntSeqs :: Graph [Integer]
monoIntSeqs = Graph [] $ fmap (`f` monoIntSeqs) nats
  where
    f n (Graph ns t) = case ns of
      top : _ | n < top -> fetch monoIntSeqs (insert' n ns)
      _ -> Graph (n : ns) $ fmap (f n) t
    insert' = insertBy (comparing Down)

但最后我可能只使用没有标识的原始整数序列,不时明确地标识节点并避免保留对 g_graph 等的引用,以便随着程序的进行进行垃圾收集清理。

阅读了 Richard Bird 和 Ralf Hinze 的函数式珠玑(functional pearl)《Trouble Shared is Trouble Halved》之后,我终于明白了如何实现我两年前一直在寻找的东西(同样基于 Edward Kmett 的技巧):

{-# LANGUAGE BangPatterns #-}
import Data.Function (fix)

-- | Infinite binary tree: left subtree, the value at this node, right subtree.
data Tree a = Tree (Tree a) a (Tree a)
  deriving Show

-- | Applies the function to every node value; the shape is unchanged.
instance Functor Tree where
  fmap fn (Tree left here right) = Tree (fn <$> left) (fn here) (fn <$> right)

-- | Looks up the value stored at position @n@ of the tree.
--
-- Position 0 is the root.  For @n > 0@, the quotient and remainder of
-- @n - 1@ by 2 give the position inside the left (remainder 0) or the right
-- (remainder 1) subtree.
--
-- Negative positions do not name any node (descending on them would never
-- reach position 0), so they are rejected with an explicit error.
index :: Tree a -> Integer -> a
index (Tree l m r) n
  | n < 0 = error ("index: negative position " ++ show n)
  | n == 0 = m
  | otherwise = case (n - 1) `divMod` 2 of
      (q, 0) -> index l q
      (q, _) -> index r q -- the remainder can only be 1 here

-- | Tree in which the node at position @n@ (in the sense of 'index') holds
-- the number @n@ itself.
nats :: Tree Integer
nats = build 0 1
  where
    -- @here@ is the value of this node; @stride@ is the distance between
    -- the roots of its two subtrees.
    build !here !stride = Tree (build leftRoot childStride) here (build rightRoot childStride)
      where
        leftRoot = here + stride
        rightRoot = leftRoot + stride
        childStride = stride * 2

-- | Node of the sequence structure: a value together with a 'Tree' of
-- child nodes.
data IntSeqTree a = IntSeqTree a (Tree (IntSeqTree a))

-- | The value stored at a node.
val :: IntSeqTree a -> a
val node = case node of
  IntSeqTree here _ -> here

-- | Moves from a node to its child stored at position @n@.
step :: Integer -> IntSeqTree t -> IntSeqTree t
step n node = case node of
  IntSeqTree _ children -> index children n

-- Structure of integer sequences in which the child of a node for label n
-- is either a fresh node (n goes in front of the node's list) or an existing
-- node found by walking from the root.
intSeqTree :: IntSeqTree [Integer]
intSeqTree = fix $ create []
  -- create p x: node with value p whose child for label n is `extend x n`;
  -- x is the node under construction itself (tied by fix).
  where create p x = IntSeqTree p $ fmap (extend x) nats
        -- extend x n: split the list of x into the leading elements greater
        -- than n and the rest.
        extend x n = case span (>n) (val x) of
                       -- No leading element is greater than n: build a new
                       -- node whose list is n followed by the list of x.
                       ([], p) -> fix $ create (n:p)
                       -- Otherwise put n after those leading elements and
                       -- reach the node from the root; foldr applies step for
                       -- the last list element first.
                       (m, p)  -> foldr step intSeqTree (m ++ n:p)

-- | Maps the function over the value of every node.
instance Functor IntSeqTree where
  fmap fn (IntSeqTree here children) = IntSeqTree (fn here) (fmap fn <$> children)

在我的用例中,我有成百上千个递增生成的类似整数序列(长度为几百个条目)。所以对我来说,这种方式比在查找函数值之前对序列进行排序更便宜(我将通过在 intSeqTree 上使用 fmap 来访问它)。