Haskell - 总结概率列表

Haskell - Sum up probability list

我正在学习学习 A Haskell,刚刚完成 "For a few monads more" 部分。在这部分我们创建了一个 newtype Prob a = Prob { getProb :: [(a, Rational)] } 并为它创建了一个 Monad 实例。这使我们能够计算非确定性计算中结果的概率,如下所示:

data Coin = Heads | Tails deriving (Show, Eq)

coin :: Prob Coin
coin = Prob [(Heads, 1%2), (Tails, 1%2)]

loadedCoin :: Prob Coin
loadedCoin = Prob [(Heads, 1%10), (Tails, 9%10)]

coinTest :: Prob Bool
coinTest = do
    a <- coin
    b <- coin
    c <- loadedCoin
    return (all (==Tails) [a,b,c])

当然,这不会产生非常漂亮的结果:

getProb coinTest
>> [(False,1 % 40),(False,9 % 40),(False,1 % 40),(False,9 % 40),(False,1 % 40),(False,9 % 40),(False,1 % 40),(True,9 % 40)]

留给 reader 的练习是编写一个简洁的函数来总结所有 Falses 和所有 Trues 所以我们得到 [(True,9 % 40),(False,31 % 40)].我设法做到了这一点。它适用于这种特殊情况,但我觉得它根本不是一个有用的功能,因为它太专业了。这是我想出的:

sumProbs :: Prob Bool -> Prob Bool
sumProbs (Prob ps) = let (trues, falses) = partition fst ps
                         ptrue = reduce trues
                         pfalse = reduce falses
                     in Prob [ptrue, pfalse]
                     where reduce = foldr1 (\(b,r) (_,r') -> (b,r+r'))

我很乐意将其概括为适用于任何 Eq a => Prob a,但到目前为止还没有成功。我正在考虑将 MapunionWith 或类似的东西一起使用。或者也许我可以利用 (a,b) 有一个 Functor b 实例这一事实?我想我缺少一些更简单更优雅的解决方案。

所以,总结一下:我如何编写一个函数 sumProbs :: (Eq a) => Prob a -> Prob a 来总结共享相同值(键)的所有概率?

使用 Map 是个好主意,但除了 Eq a 之外,您还需要 Ord a。如果你同意,那么我们还可以做更简单的列表解决方案:只需将 partition 替换为 sortBy and groupBy:

的组合
import Data.List (groupBy, sortBy)
import Data.Function (on)

sumProbs :: (Ord a) => Prob a -> Prob a
sumProbs (Prob ps) = Prob . map reduce
                   . groupBy ((==)`on`fst)
                   $ sortBy (compare`on`fst) ps

如果您使用 Data.Map,那么 fromListWith and toList 将执行:

import Data.Map (toList, fromListWith)

newtype Prob a = Prob { getProb :: [(a, Rational)] }
    deriving Show

sumProbs :: (Ord a) => Prob a -> Prob a
sumProbs = Prob . toList . fromListWith (+) . getProb

Ord a 放宽到 Eq a 将需要较低效率的二次计算;类似于:

sumProbs :: (Eq a) => Prob a -> Prob a
sumProbs = Prob . foldr go [] . getProb
    where
    go (x, y) = run
        where
        run [] = (x, y):[]
        run ((a, b):rest)
            | x == a    = (x, y + b): rest
            | otherwise = (a, b): run rest