在 Haskell 中使用 RecursiveDo 生成重复的 EDSL 代码

Duplicate EDSL code generated with RecursiveDo in Haskell

使用 ghc v8.0.1,使用 -O2 标志编译。

我在使用 RecursiveDo (mdo) 时遇到问题。有两个略有不同的函数应该产生相同的输出,但它们却没有。

以下函数产生正确的输出:

proc2 :: Assembler ()
proc2 = mdo
    set (R 0) (I 0x5a5a)
    let r = (R 0)
    let bits = (I 2)
    let count = (R 70)
    set count bits
    _loop <- label
    cmp count (I 0)
    je _end
    add r r
    sub count (I 1)
    jmp _loop
    _end <- label
    end

正确的输出是

0000:> SET (R 0) (I 23130)
0001:  SET (R 70) (I 2)
0002:  CMP (R 70) (I 0)
0003:  JE (A 7)
0004:  ADD (R 0) (R 0)
0005:  SUB (R 70) (I 1)
0006:  JMP (A 2)
0007:  END

以下函数产生不正确的输出:

proc1 :: Assembler ()
proc1 = mdo
    set (R 0) (I 0x5a5a)
    shl (R 0) (I 1)
    end

shl :: (MonadFix m, Instructions m) => Operand -> Operand -> m ()
shl r@(R _) bits = mdo
    let count = (R 70)
    set count bits
    repeatN count $ mdo
        add r r     -- shift left by one
shl _ _ = undefined

repeatN :: (MonadFix m, Instructions m) => Operand -> m a -> m a
repeatN n@(R _) body = mdo
    _loop <- label
    cmp n (I 0)
    je _end
    retval <- body
    sub n (I 1)
    jmp _loop
    _end <- label
    return retval
repeatN _ _ = undefined

错误的输出是

0000:> SET (R 0) (I 23130)
0001:  SET (R 70) (I 1)
0002:  CMP (R 70) (I 0)

0003:  JE (A 7)
0004:  ADD (R 0) (R 0)
0005:  SUB (R 70) (I 1)
0006:  JMP (A 2)

0007:  JE (A 7)
0008:  ADD (R 0) (R 0)
0009:  SUB (R 70) (I 1)
000A:  JMP (A 2)

000B:  END

从 0007 到 000A 的行是从 0003 到 0006 的行的副本,并且(在这种特殊情况下)最终结果是在 0007 处的无限循环。

有问题的代码在 Haskell 中实现了一个 EDSL(Ting Pen 的汇编程序)。程序的输出是听笔的机器码。

我正在使用 MonadFix 来捕获汇编语言中的前向标签,当我使用一些代码组合器时,我得到了不正确的输出(一些生成的代码被复制)。我已经包含了一些跟踪代码并且能够跟踪代码生成。 RecursiveDo 机制有时会生成重复代码(另请参见下面提供的程序的输出)。

{-# LANGUAGE RecursiveDo #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE MagicHash #-}

module TingBugChase1 where

import Data.Word (Word16)

import Control.Applicative (Applicative(..))
import Control.Monad (liftM, ap, return)
import Control.Monad.Fix (MonadFix(..))
import Text.Printf (printf)


i# :: (Integral a, Num b) => a -> b
i# = fromIntegral

-- ================================================================= Assembler

data Instruction = END
    | CLEARVER
    | SET Operand Operand
    | CMP Operand Operand
    | AND Operand Operand
    | OR Operand Operand
    | NOT Operand
    | JMP Operand
    | JE Operand
    | JNE Operand
    | JG Operand
    | JGE Operand
    | JB Operand
    | JBE Operand
    | ADD Operand Operand
    | SUB Operand Operand
    | RETURN
    | CALLID Operand
    | PLAYOID Operand
    | PAUSE Operand
    {- … -}
    deriving Show


data AsmState = AsmState
    { _code :: [Instruction]
    , _location :: Location
    , _codeHistory :: [([Instruction],[Instruction])]
    }

disasmCode :: [Instruction] -> Int -> [String]
disasmCode [] _ = ["[]"]
disasmCode code pc = map disasm1 $ zip [0..] code
    where
        disasm1 :: (Int, Instruction) -> String
        disasm1 (addr, instr) = printf "%04X:%s %s" addr (pointer addr) (show instr)
        pointer :: Int -> String
        pointer addr = if addr == i# pc then ">" else " "

instance Show AsmState where
    show (AsmState {..}) = "AsmState {" ++
        unlines
        [ "Code:\n" ++ unlines (disasmCode _code 0)
        , "Location: " ++ (show _location)
        , "History:\n" ++ unlines (map disasmHistory _codeHistory)
        ] ++ "}"
        where
            disasmHistory (a,b) = 
                unlines $
                    disasmCode a 0
                    ++ ["++"] ++
                    disasmCode b 0


data Assembler a = Assembler { runAsm :: AsmState -> (a, AsmState) }

-- https://wiki.haskell.org/Functor-Applicative-Monad_Proposal
-- Monad (Assembler w)
instance Functor Assembler where
    fmap = liftM

instance Applicative Assembler where
    {- move the definition of `return` from the `Monad` instance here -}
    pure a = Assembler $ \s -> (a,s)
    (<*>) = ap

instance Monad Assembler where
    return = pure -- redundant since GHC 7.10 due to default impl
    x >>= fy = Assembler $ \s -> 
            let 
                (a, sA) = runAsm x s
                (b, sB) = runAsm (fy a) sA
            in (b, 
                sB 
                { _code = _code sA ++ _code sB
                , _location = _location sB
                , _codeHistory = _codeHistory sB ++ [(_code sA, _code sB)]
                })

instance MonadFix Assembler where
    mfix f = Assembler $ \s -> 
        let (a, sA) = runAsm (f a) s 
        in (a, sA)

{- Append the list of instructions to the code stream. -}
append :: [Instruction] -> Assembler ()
append xs = Assembler $ \s -> 
    ((), s { _code = xs, _location = newLoc $ _location s })
    where
        newLoc (A loc) = A $ loc + (i# . length $ xs)
        newLoc _ = undefined

-- ========================================================= Instructions

data Operand = 
    R Word16    -- registers
    | I Word16  -- immediate value (integer)
    | A Word16  -- address (location)
    deriving (Eq, Show)

type Location = Operand

-- Instructions
class Instructions m where
    end :: m ()
    clearver :: m ()
    set :: Operand -> Operand -> m ()
    cmp :: Operand -> Operand -> m ()
    and :: Operand -> Operand -> m ()
    or :: Operand -> Operand -> m ()
    not :: Operand -> m ()
    jmp :: Location -> m ()
    je :: Location -> m ()
    jne :: Location -> m ()
    jg :: Location -> m ()
    jge :: Location -> m ()
    jb :: Location -> m ()
    jbe :: Location -> m ()
    add :: Operand -> Operand -> m ()
    sub :: Operand -> Operand -> m ()
    ret :: m ()
    callid :: Operand -> m ()
    playoid :: Operand -> m ()
    pause :: Operand -> m ()

    label :: m Location


{- Code combinators -}
repeatN :: (MonadFix m, Instructions m) => Operand -> m a -> m a
repeatN n@(R _) body = mdo
    _loop <- label
    cmp n (I 0)
    je _end
    retval <- body
    sub n (I 1)
    jmp _loop
    _end <- label
    return retval
repeatN _ _ = undefined

{- 
    Derived (non-native) instructions, aka macros 
    Scratch registers r70..r79
-}
shl :: (MonadFix m, Instructions m) => Operand -> Operand -> m ()
shl r@(R _) bits = mdo
    -- allocate registers
    let count = (R 70)

    set count bits
    repeatN count $ mdo
        add r r     -- shift left by one
shl _ _ = undefined


instance Instructions Assembler where 
    end = append [END]
    clearver = append [CLEARVER]
    set op1 op2 = append [SET op1 op2]
    cmp op1 op2 = append [CMP op1 op2]
    and op1 op2 = append [AND op1 op2]
    or op1 op2 = append [OR op1 op2]
    not op1 = append [NOT op1]

    jmp op1 = append [JMP op1]
    je op1 = append [JE op1]
    jne op1 = append [JNE op1]
    jg op1 = append [JG op1]
    jge op1 = append [JGE op1]
    jb op1 = append [JB op1]
    jbe op1 = append [JBE op1]

    add op1 op2 = append [ADD op1 op2]
    sub op1 op2 = append [SUB op1 op2]

    ret = append [RETURN]
    callid op1 = append [CALLID op1]
    playoid op1 = append [PLAYOID op1]
    pause op1 = append [PAUSE op1]

    {- The label function returns the current index of the output stream. -}
    label = Assembler $ \s -> (_location s, s { _code = [] })

-- ========================================================= Tests

asm :: Assembler () -> AsmState
asm proc = snd . runAsm proc $ AsmState 
            { _code = []
            , _location = A 0
            , _codeHistory = [] 
            }

doTest :: Assembler () -> String -> IO ()
doTest proc testName = do
    let ass = asm proc
    putStrLn testName
    putStrLn $ show ass

proc1 :: Assembler ()
proc1 = mdo
    set (R 0) (I 0x5a5a)
    shl (R 0) (I 1)
    end

proc2 :: Assembler ()
proc2 = mdo
    set (R 0) (I 0x5a5a)
    -- allocate registers
    let r = (R 0)
    let bits = (I 2)
    let count = (R 70)

    set count bits
    _loop <- label
    cmp count (I 0)
    je _end
    add r r
    sub count (I 1)
    jmp _loop
    _end <- label
    end

-- ========================================================= Main

main :: IO ()
main = do
    doTest proc1 "Incorrect Output"
    doTest proc2 "Correct Output"

程序的输出如下。

proc1 的错误输出:

AsmState {Code:
0000:> SET (R 0) (I 23130)
0001:  SET (R 70) (I 1)
0002:  CMP (R 70) (I 0)
0003:  JE (A 7)

0004:  ADD (R 0) (R 0)
0005:  SUB (R 70) (I 1)
0006:  JMP (A 2)
0007:  JE (A 7)
0008:  ADD (R 0) (R 0)
0009:  SUB (R 70) (I 1)
000A:  JMP (A 2)
000B:  END

Location: A 8
History:
[]
++
[]

0000:> JMP (A 2)
++
[]

0000:> SUB (R 70) (I 1)
++
0000:> JMP (A 2)

0000:> ADD (R 0) (R 0)
++
0000:> SUB (R 70) (I 1)
0001:  JMP (A 2)

0000:> JE (A 7)
++
0000:> ADD (R 0) (R 0)
0001:  SUB (R 70) (I 1)
0002:  JMP (A 2)

这是代码重复发生的地方:

0000:> JE (A 7)
0001:  ADD (R 0) (R 0)
0002:  SUB (R 70) (I 1)
0003:  JMP (A 2)
++
0000:> JE (A 7)
0001:  ADD (R 0) (R 0)
0002:  SUB (R 70) (I 1)
0003:  JMP (A 2)

0000:> CMP (R 70) (I 0)
++
0000:> JE (A 7)
0001:  ADD (R 0) (R 0)
0002:  SUB (R 70) (I 1)
0003:  JMP (A 2)
0004:  JE (A 7)
0005:  ADD (R 0) (R 0)
0006:  SUB (R 70) (I 1)
0007:  JMP (A 2)

[]
++
0000:> CMP (R 70) (I 0)
0001:  JE (A 7)
0002:  ADD (R 0) (R 0)
0003:  SUB (R 70) (I 1)
0004:  JMP (A 2)
0005:  JE (A 7)
0006:  ADD (R 0) (R 0)
0007:  SUB (R 70) (I 1)
0008:  JMP (A 2)

0000:> SET (R 70) (I 1)
++
0000:> CMP (R 70) (I 0)
0001:  JE (A 7)
0002:  ADD (R 0) (R 0)
0003:  SUB (R 70) (I 1)
0004:  JMP (A 2)
0005:  JE (A 7)
0006:  ADD (R 0) (R 0)
0007:  SUB (R 70) (I 1)
0008:  JMP (A 2)

0000:> SET (R 70) (I 1)
0001:  CMP (R 70) (I 0)
0002:  JE (A 7)
0003:  ADD (R 0) (R 0)
0004:  SUB (R 70) (I 1)
0005:  JMP (A 2)
0006:  JE (A 7)
0007:  ADD (R 0) (R 0)
0008:  SUB (R 70) (I 1)
0009:  JMP (A 2)
++
0000:> END

0000:> SET (R 0) (I 23130)
++
0000:> SET (R 70) (I 1)
0001:  CMP (R 70) (I 0)
0002:  JE (A 7)
0003:  ADD (R 0) (R 0)
0004:  SUB (R 70) (I 1)
0005:  JMP (A 2)
0006:  JE (A 7)
0007:  ADD (R 0) (R 0)
0008:  SUB (R 70) (I 1)
0009:  JMP (A 2)
000A:  END
}

proc2 的正确输出:

AsmState {Code:
0000:> SET (R 0) (I 23130)
0001:  SET (R 70) (I 2)
0002:  CMP (R 70) (I 0)
0003:  JE (A 7)
0004:  ADD (R 0) (R 0)
0005:  SUB (R 70) (I 1)
0006:  JMP (A 2)
0007:  END

Location: A 8
History:
[]
++
[]

0000:> JMP (A 2)
++
[]

0000:> SUB (R 70) (I 1)
++
0000:> JMP (A 2)

0000:> ADD (R 0) (R 0)
++
0000:> SUB (R 70) (I 1)
0001:  JMP (A 2)

0000:> JE (A 7)
++
0000:> ADD (R 0) (R 0)
0001:  SUB (R 70) (I 1)
0002:  JMP (A 2)

0000:> JE (A 7)
0001:  ADD (R 0) (R 0)
0002:  SUB (R 70) (I 1)
0003:  JMP (A 2)
++
0000:> END

0000:> CMP (R 70) (I 0)
++
0000:> JE (A 7)
0001:  ADD (R 0) (R 0)
0002:  SUB (R 70) (I 1)
0003:  JMP (A 2)
0004:  END

[]
++
0000:> CMP (R 70) (I 0)
0001:  JE (A 7)
0002:  ADD (R 0) (R 0)
0003:  SUB (R 70) (I 1)
0004:  JMP (A 2)
0005:  END

0000:> SET (R 70) (I 2)
++
0000:> CMP (R 70) (I 0)
0001:  JE (A 7)
0002:  ADD (R 0) (R 0)
0003:  SUB (R 70) (I 1)
0004:  JMP (A 2)
0005:  END

0000:> SET (R 0) (I 23130)
++
0000:> SET (R 70) (I 2)
0001:  CMP (R 70) (I 0)
0002:  JE (A 7)
0003:  ADD (R 0) (R 0)
0004:  SUB (R 70) (I 1)
0005:  JMP (A 2)
0006:  END
}

我认为问题在于您的 monad 实例存在缺陷。它看起来好像应该是一个 State monad,但是 >>= 的定义做了一些看起来更像 Writer monad 的操作(使用 []Last 幺半群)。我很确定至少 mfix>>= 不兼容,但我的猜测是即使 >>= 本身也可能不符合 monad 法则。

我没有寻找好的反例,但我可以为您提供一个我认为可行的版本,该版本基于可用的标准工具。我将 post 仅更改部分,我的整个测试代码可用 here:

import Control.Monad.RWS
import Control.Monad.Fix (MonadFix(..))
import qualified Data.Foldable as F
import qualified Data.Sequence as S
import Text.Printf (printf)

-- ...

状态是当前位置,writer monoid 是一系列指令(我使用 Seq 而不是 [],因为 ++ 的性能可能很差列表;如果需要,您也可以使用 DList):

newtype Assembler a = Assembler (RWS () (S.Seq Instruction) Word16 a)
    deriving (Functor, Applicative, Monad, MonadFix)

{- Append the list of instructions to the code stream. -}
append :: [Instruction] -> Assembler ()
append xs = Assembler . rws $ \_ loc -> ((), loc + (i# . length $ xs), S.fromList xs)

asm :: Assembler () -> AsmState
asm (Assembler proc) =
    let (location, code) = execRWS proc () 0
    in AsmState { _code = F.toList code
                , _location = A location
                , _codeHistory = [] -- I just ignored this field ...
                }

instance Instructions Assembler where 
    -- all same as before, except
    label = A <$> Assembler get

现在将 AsmState 重命名为 AsmResult.

可能更有意义

此外,我建议像我一样只使用 Word16,而不是将 Location 与部分函数一起使用,或者定义一个只捕获位置的新类型,然后在 [=28] 中使用它=].这使代码更安全、更简洁。

(无论如何,最好有一个可以测试此类问题的测试套件。)