我在使用 monad 时不断出错

I keep getting errors when I use monads

我正在尝试使用 monad 编写代码,但总是报错。是我哪里做错了吗?我不明白为什么会出现这个错误:“错误:此表达式的类型为 (expr -> 'a) -> 'a,但此处应为 expr 类型”。该错误出现在测试用例中。不使用 monad 时,代码似乎工作正常。是我使用 monad 的方式不对吗?

(** The variables an expression may mention. *)
type var =
  | X
  | Y
  | Z

(** Arithmetic expressions over integer literals and variables. *)
type expr =
  | N of int  (** integer literal *)
  | V of var  (** variable reference *)
  | Add of expr * expr  (** sum of two sub-expressions *)
  | Mul of expr * expr  (** product of two sub-expressions *)

(** [add a b] builds the sum of [a] and [b] with light constant folding:
    a literal [N 0] on either side is dropped, two literals are added
    together, and anything else becomes an [Add] node. *)
let add a b =
  match a, b with
  | N 0, other | other, N 0 -> other
  | N m, N n -> N (m + n)
  | _ -> Add (a, b)

(** [mul a b] builds the product of [a] and [b] with light simplification:
    a literal [N 0] on either side makes the whole product [N 0], a literal
    [N 1] on either side is the identity, two literals are multiplied, and
    anything else becomes a [Mul] node. *)
let mul a b =
  match a, b with
  | N 0, _ | _, N 0 -> N 0
  | N 1, other | other, N 1 -> other
  | N m, N n -> N (m * n)
  | _ -> Mul (a, b)

(** The identity function; handy as the final continuation that simply
    returns the result of a continuation-passing computation. *)
let id = fun v -> v

(** A minimal continuation monad. A computation of type [('a -> 'r) -> 'r]
    is a function waiting for the continuation [k] that will consume its
    result. *)
module Cont = struct
  (** [ret v] immediately passes [v] to whatever continuation it is given. *)
  let ret v k = k v

  (** [m >>= f] runs [m], feeds its result to [f], and hands the resulting
      computation the same final continuation [k]. *)
  let ( >>= ) m f k = m (fun v -> f v k)
end

(** [deriv expr var] differentiates [expr] with respect to [var] in
    continuation-passing style. The result is a computation of type
    [(expr -> 'r) -> 'r], not an [expr], so it must be given a continuation,
    e.g. [deriv e X id] to get the expression back, or [deriv e X to_str] to
    chain into rendering. Intermediate results are simplified via [add] and
    [mul]. *)
let rec deriv expr var =
  let open Cont in
  match expr with
  | N _ -> ret (N 0)
  (* [id ret] is the same function as [ret]; the [id] is redundant. *)
  | V v -> if v = var then ret (N 1) else id ret (N 0)
  | Add (a, b) ->
      deriv a var >>= fun da ->
      deriv b var >>= fun db ->
      ret (add da db)
  | Mul (a, b) ->
      (* Product rule, built as a * b' + b * a'; this operand order is what
         the printed results follow. *)
      deriv b var >>= fun db ->
      deriv a var >>= fun da ->
      ret (add (mul a db) (mul b da))

(** [to_str expr] renders [expr] as a string in continuation-passing style.
    The result is a computation of type [(string -> 'r) -> 'r], so it must
    be given a continuation such as [id] or [print_endline]. Sums are
    wrapped in parentheses; products are printed as "a * b" with no
    parentheses of their own. *)
let rec to_str expr =
  let open Printf in
  let open Cont in
  (* Exhaustive over [var]; no catch-all, so a new variable is a
     compile-time error rather than a runtime [assert false]. *)
  let var_to_str = function X -> "x" | Y -> "y" | Z -> "z" in
  match expr with
  | N a -> ret (sprintf "%d" a)
  | V v -> ret (var_to_str v)
  | Add (a, b) ->
      to_str a >>= fun x ->
      to_str b >>= fun y ->
      ret (sprintf "(%s + %s)" x y)
  | Mul (a, b) ->
      to_str a >>= fun x ->
      to_str b >>= fun y ->
      ret (sprintf "%s * %s" x y)

(* test cases *)
(* NOTE: as posted, the [to_str (deriv ...)] lines do not type-check.
   [deriv] returns a continuation of type (expr -> 'a) -> 'a rather than an
   [expr], which is exactly the reported error. [to_str] likewise returns a
   continuation, not a string. *)
let a = add (V X) (N 3)
let _ = to_str a                (* "(x + 3)" *)
let _ = to_str (deriv a X)      (* "1" *)
let _ = to_str (deriv a Y)      (* "0" *)

let b = add (mul (N 2) (V X)) (mul (V Y) (N 3))
let _ = to_str b                (* "(2 * x + y * 3)" *)
let _ = to_str (deriv b X)      (* "2" *)
let _ = to_str (deriv b Y)      (* "3" *)

(* The expected strings for c are mathematically right, but [deriv] builds
   the product rule as a * b' + b * a', so once the calls are fixed they
   print as "(x * y + (x + 3) * y)" and "(x + 3) * x". *)
let c = mul (mul (V X) (V Y)) (add (V X) (N 3))
let _ = to_str c                (* "x * y * (x + 3)" *)
let _ = to_str (deriv c X)      (* "(x * y + y * (x + 3))" *)
let _ = to_str (deriv c Y)      (* "x * (x + 3)" *)

在您的测试用例中,您是以直接方式调用 to_str 的:

(* test cases, as in the question *)
(* Each [to_str (deriv ...)] line fails: deriv returns a Cont (a function
   awaiting a continuation), not an expr. *)
let a = add (V X) (N 3)
let _ = to_str a
let _ = to_str (deriv a X) (* deriv returns a Cont *)
let _ = to_str (deriv a Y) (* deriv returns a Cont *)

let b = add (mul (N 2) (V X)) (mul (V Y) (N 3))
let _ = to_str b
let _ = to_str (deriv b X) (* deriv returns a Cont *)
let _ = to_str (deriv b Y) (* deriv returns a Cont *)

let c = mul (mul (V X) (V Y)) (add (V X) (N 3))
let _ = to_str c
let _ = to_str (deriv c X) (* deriv returns a Cont *)
let _ = to_str (deriv c Y) (* deriv returns a Cont *)

但正如我们所知,deriv 返回的是一个延续(continuation)!您必须把 to_str 作为延续传给它……

(* test cases *)
(* Each [deriv ...] now receives [to_str] as its continuation, so these
   bindings type-check. Every value bound here is still a continuation,
   though: [to_str] also returns one, so nothing is printed yet. *)
let a = add (V X) (N 3)
let _ = to_str a
let _ = (deriv a X) to_str (* (deriv ...) then continue with to_str *)
let _ = (deriv a Y) to_str (* (deriv ...) then continue with to_str *)

let b = add (mul (N 2) (V X)) (mul (V Y) (N 3))
let _ = to_str b
let _ = (deriv b X) to_str (* (deriv ...) then continue with to_str *)
let _ = (deriv b Y) to_str (* (deriv ...) then continue with to_str *)

let c = mul (mul (V X) (V Y)) (add (V X) (N 3))
let _ = to_str c
let _ = (deriv c X) to_str (* (deriv ...) then continue with to_str *)
let _ = (deriv c Y) to_str (* (deriv ...) then continue with to_str *)

请注意,to_str 同样返回一个延续!那么该如何打印结果呢?可以先用 id 作为最终的延续取出字符串,再把它交给 print_endline:

(* [id] as the final continuation extracts the string; [let ()] checks that
   each line is a unit-returning side effect. *)
let () = print_endline ((to_str a) id)          (* (x + 3) *)
let () = print_endline ((deriv a X) to_str id)  (* 1 *)
(* ...and likewise for the remaining test cases *)

延续(continuation)会把程序“由内而外”地颠倒过来!与其先用 id 取出结果再传给 print_endline,不如直接把 print_endline 作为延续传进去:

(* print_endline itself is the final continuation. [let ()] asserts each
   line evaluates to unit, i.e. the continuation chain was fully applied. *)
let a = add (V X) (N 3)
let () = (to_str a) print_endline
let () = (deriv a X) to_str print_endline
let () = (deriv a Y) to_str print_endline

let b = add (mul (N 2) (V X)) (mul (V Y) (N 3))
let () = (to_str b) print_endline
let () = (deriv b X) to_str print_endline
let () = (deriv b Y) to_str print_endline

let c = mul (mul (V X) (V Y)) (add (V X) (N 3))
let () = (to_str c) print_endline
let () = (deriv c X) to_str print_endline
let () = (deriv c Y) to_str print_endline

输出

(x + 3)
1
0
(2 * x + y * 3)
2
3
x * y * (x + 3)
(x * y + (x + 3) * y)
(x + 3) * x

我还注意到这一处:

... else id ret (N 0)

其中的 id 可以安全地删除:

... else ret (N 0)