How can I reduce this tree in parallel in Haskell?

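I have a simple tree that stores a sequence of values in its leaves, along with a few simple functions to make testing easier.

If I had an unlimited number of processors and the tree is balanced, I should be able to reduce the tree in logarithmic time using any associative binary operation (+, *, min, lcm).

By making Tree an instance of Foldable, I can reduce the tree sequentially, left to right or right to left, using built-in functions, but that takes linear time.

How can I use Haskell to reduce such a tree in parallel?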

{-# LANGUAGE DeriveFoldable #-}

data Tree a = Leaf a | Node (Tree a) (Tree a)
            deriving (Show, Foldable)

toList :: Tree a -> [a]
toList = foldr (:) []

range :: Int -> Int -> Tree Int
range x y
  | x < y     = Node (range x y') (range x' y)
  | otherwise = Leaf x
  where
    y' = quot (x + y) 2
    x' = y' + 1
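
To make the logarithmic-time claim concrete, here is a small sketch that measures the tree built by range. The helpers depth and size are illustrative names introduced here, not part of the question. For a balanced tree the depth grows with the logarithm of the number of leaves, and the depth is the length of the critical path of a parallel reduction.

```haskell
data Tree a = Leaf a | Node (Tree a) (Tree a)

-- Same definition as in the question: a balanced tree with leaves x..y.
range :: Int -> Int -> Tree Int
range x y
  | x < y     = Node (range x y') (range x' y)
  | otherwise = Leaf x
  where
    y' = quot (x + y) 2
    x' = y' + 1

-- Hypothetical helper: number of Node levels on the longest root-to-leaf path.
depth :: Tree a -> Int
depth (Leaf _)   = 0
depth (Node l r) = 1 + max (depth l) (depth r)

-- Hypothetical helper: number of leaves.
size :: Tree a -> Int
size (Leaf _)   = 1
size (Node l r) = size l + size r

main :: IO ()
main = do
  let t = range 1 1024
  print (size t)   -- 1024
  print (depth t)  -- 10
```

With enough processors, all combines on one level could run at the same time, so only depth t combining steps lie on the critical path, compared with size t - 1 steps for a sequential fold.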

The naive fold is written like this:

cata fLeaf fNode = go where
    go (Leaf z) = fLeaf z
    go (Node l r) = fNode (go l) (go r)

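I would guess the parallel version is a simple adaptation:

-- needs: import Control.Parallel (par, pseq)
parCata fLeaf fNode = go where
    go (Leaf z) = fLeaf z
    -- spark the left result, evaluate the right one, then combine them
    go (Node l r) = gol `par` gor `pseq` fNode gol gor where
        gol = go l
        gor = go r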

But it can even be written in terms of cata:

parCata fLeaf fNode = cata fLeaf (\l r -> l `par` r `pseq` fNode l r)
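
For reference, here is a minimal self-contained sketch that puts these pieces together and reduces a tree with an associative operation. Two things in it are assumptions of this sketch rather than part of the answer: parReduce is a hypothetical helper name, and par and pseq are imported from GHC.Conc in base so that no extra package is needed (Control.Parallel from the parallel package exports the same combinators).

```haskell
import GHC.Conc (par, pseq)

data Tree a = Leaf a | Node (Tree a) (Tree a)

-- Same definition as in the question: a balanced tree with leaves x..y.
range :: Int -> Int -> Tree Int
range x y
  | x < y     = Node (range x y') (range x' y)
  | otherwise = Leaf x
  where
    y' = quot (x + y) 2
    x' = y' + 1

-- Sequential fold over the tree structure.
cata :: (a -> b) -> (b -> b -> b) -> Tree a -> b
cata fLeaf fNode = go
  where
    go (Leaf z)   = fLeaf z
    go (Node l r) = fNode (go l) (go r)

-- Parallel variant: spark the left result, evaluate the right one,
-- then combine the two.
parCata :: (a -> b) -> (b -> b -> b) -> Tree a -> b
parCata fLeaf fNode = cata fLeaf (\l r -> l `par` (r `pseq` fNode l r))

-- Hypothetical helper: reduce the leaves with an associative operation.
parReduce :: (a -> a -> a) -> Tree a -> a
parReduce = parCata id

main :: IO ()
main = do
  let t = range 1 100
  print (parReduce (+) t)  -- 5050
  print (parReduce min t)  -- 1
```

The sparks only run on other cores if the program is built with -threaded and started with +RTS -N. Note also that par and pseq evaluate only to weak head normal form, and that for cheap operations such as (+) on Int the cost of creating sparks can easily outweigh the work being parallelised.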

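Update

I originally answered this question under the assumption that the reduction operation was not expensive. Here is an answer that performs an associative reduction over chunks of n elements.

That is, suppose op is an associative binary operation and you want to compute foldr1 op [1..6]. The code here will evaluate it as:

op (op (op 1 2) (op 3 4)) (op 5 6)

which allows the inner applications to be evaluated in parallel.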

import Control.Parallel.Strategies
import System.TimeIt
import Data.List.Split
import Debug.Trace

recChunk :: ([a] -> a) -> Int -> [a] -> a
recChunk op n xs =
  case chunksOf n xs of
    [a] -> op a
    cs  -> recChunk op n $ parMap rseq op cs

data N = N Int | Op [N]
  deriving (Show)

test1 = recChunk Op 2 $ map N [1..10]
test2 = recChunk Op 3 $ map N [1..10]

fib 0 = 0
fib 1 = 1
fib n = fib (n-1) + fib (n-2)

fib' n | trace msg False = undefined
  where msg = "fib called with " ++ show n
fib' n = fib n

sumFib :: [Int] -> Int
sumFib xs | trace msg False = undefined
  where msg = "sumFib: " ++ show xs
sumFib xs = seq s (s + (mod (fib' (40 + mod s 2)) 1))
  where s = sum xs

main = do
  timeIt $ print $ recChunk sumFib 2 [1..20]
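
To see the grouping that recChunk produces without installing any packages, here is a sequential sketch of the same recursion. It is not the parallel code itself: chunksOf' is a hand-rolled stand-in for Data.List.Split.chunksOf, recChunkSeq is a hypothetical name, and plain map replaces parMap rseq.

```haskell
-- Hand-rolled stand-in for Data.List.Split.chunksOf (assumes n >= 1).
chunksOf' :: Int -> [a] -> [[a]]
chunksOf' _ [] = []
chunksOf' n xs = let (h, t) = splitAt n xs in h : chunksOf' n t

-- Sequential version of recChunk: reduce every chunk, then recurse on the
-- list of partial results until a single chunk is left.
-- Assumes a non-empty input and n >= 2; otherwise it does not terminate.
recChunkSeq :: ([a] -> a) -> Int -> [a] -> a
recChunkSeq op n xs =
  case chunksOf' n xs of
    [a] -> op a
    cs  -> recChunkSeq op n (map op cs)

-- Records the shape of the reduction instead of computing a value.
data N = N Int | Op [N]
  deriving (Show, Eq)

main :: IO ()
main = do
  print (recChunkSeq Op 2 (map N [1..6]))
  -- Op [Op [Op [N 1,N 2],Op [N 3,N 4]],Op [Op [N 5,N 6]]]
  print (recChunkSeq sum 2 [1..20 :: Int])
  -- 210
```

Each pass shrinks the list by a factor of about n, so the number of passes is logarithmic in the length of the input. In the parallel version, the chunks within one pass are the units that parMap can evaluate simultaneously.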

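Original answer

Since you have an associative operation, you can just use your toList function together with parMap or parList.

Evaluating a list in parallel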

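Here is some demo code that sums the fib of each Leaf. I use parBuffer to avoid creating too many sparks; if your tree is small, this isn't necessary.

I am loading the tree from a file because GHC with -O2 seems to detect common sub-expressions in my test tree.

Also, adjust rseq to your needs; you may need rdeepseq instead, depending on what you are accumulating.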

{-# LANGUAGE DeriveFoldable #-}

import Control.Parallel.Strategies
import System.Environment
import Control.DeepSeq
import System.TimeIt
import Debug.Trace

fib 0 = 0
fib 1 = 1
fib n = fib (n-1) + fib (n-2)

fib' n | trace msg False = undefined
  where msg = "fib called with " ++ show n
fib' n = fib n

data Tree a = Leaf a | Node (Tree a) (Tree a)
            deriving (Show, Read, Foldable)

toList :: Tree a -> [a]
toList = foldr (:) []

computeSum :: Int -> Tree Int -> Int
computeSum k t = sum $ runEval $ parBuffer k rseq $ map fib' $ toList t

main = do
  tree <- fmap read $ readFile "tree.in"
  timeIt $ print $ computeSum 4 tree
  return ()
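
The program above expects a file named tree.in that read can parse. A minimal sketch for producing such a file, assuming the derived Show format and the range function from the question, could look like this. The leaf values 20..30 are an arbitrary choice for illustration, not the input used for the timings above.

```haskell
data Tree a = Leaf a | Node (Tree a) (Tree a)
            deriving (Show, Read, Eq)

-- Same definition as in the question: a balanced tree with leaves x..y.
range :: Int -> Int -> Tree Int
range x y
  | x < y     = Node (range x y') (range x' y)
  | otherwise = Leaf x
  where
    y' = quot (x + y) 2
    x' = y' + 1

main :: IO ()
main =
  -- The derived Show output is exactly what the derived Read instance parses.
  writeFile "tree.in" (show (range 20 30))
```

Because the tree is read at run time, the compiler cannot see its contents and share repeated fib calls between leaves.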