Commits

Anonymous committed 175c586

init: Monads, ErrorMonad, StateMonad

Comments (0)

Files changed (4)

+^src/.*\.(glob|vo)$
+(*
+Usage:
+
+    Require Import ErrorMonad.
+
+    Module MyErrorType <: ERR_TYPE.
+      Definition E := <my_error_type>.
+    End MyErrorType.
+
+Then:
+
+1. Transformer:
+
+    Module Error_my_error_type := ErrorT (MyErrorType) (InnerMonad).
+    Import Error_my_error_type. (* for functions *)
+    Import Ops. (* for notations *)
+
+2. Monad:
+
+    Require Import ErrorMonad.
+    Module Error_my_error_type := Error (MyErrorType).
+    Import Error_my_error_type.
+    Import Ops.
+
+
+Specific functions:
+
+  - throw {A} : E -> m A
+  - catch : m A -> (E -> m A) -> m A
+  - run_error : m A -> InnerMonad.m (err E A)
+    where Inductive err E A := | Eok a | Eerr e.
+*)
+
+
+Require Import Monads.
+
+Inductive err E A :=
+| Eok : A -> err E A
+| Eerr : E -> err E A
+.
+
+Implicit Arguments Eok [E A].
+Implicit Arguments Eerr [E A].
+
+
+Module Type ERR_TYPE.
+  Parameter E : Type.
+End ERR_TYPE.
+
+Module Type MONADT_ET (Et : ERR_TYPE) (M : MONAD).
+  Include MONADT (M).
+End MONADT_ET.
+
+(* for MapError: *)
+Module Type ERROR_T (Et : ERR_TYPE) (M : MONAD).
+  Definition m A := M.m (err Et.E A).
+End ERROR_T.
+
+
+Module ErrorT (Et : ERR_TYPE) (M : MONAD)
+<: MONADT_ET (Et) (M)
+<: ERROR_T (Et) (M).
+
+  Module Mo.
+  Definition tm A := M.m (err Et.E A).
+
+  Definition lift A (ma : M.m A) : tm A :=
+    M.bind (fun a => M.ret (Eok a)) ma
+  .
+  Implicit Arguments lift [[A]].
+
+  Definition ret A (a : A) : tm A :=
+    M.ret (Eok a)
+  .
+  Global Implicit Arguments ret [[A]].
+
+  Definition bind A B
+    (f : A -> tm B)
+    (tma : tm A)
+   :
+    tm B
+   :=
+    M.bind
+      (fun ea =>
+         match ea with
+         | Eok a => f a
+         | Eerr e => M.ret (Eerr e)
+         end
+      )
+      tma
+  .
+  Global Implicit Arguments bind [A B].
+
+
+  Definition map_error {A} (mapper : Et.E -> Et.E) (tma : tm A) : tm A :=
+    M.bind
+      (fun ea =>
+         match ea with
+         | Eok _ => M.ret ea
+         | Eerr e => M.ret (Eerr (mapper e))
+         end
+      )
+      tma
+  .
+
+
+  Theorem lift_ret : forall A (x : A),
+    lift (M.ret x) = (ret x)
+  .
+  intros.
+  unfold lift.
+  rewrite M.bind_unit.
+  unfold ret.
+  reflexivity.
+  Qed.
+
+
+  Theorem bind_unit : forall A B (a : A) (f : A -> tm B),
+    bind f (ret a) = f a
+  .
+  intros.
+  unfold ret, bind.
+  rewrite M.bind_unit.
+  reflexivity.
+  Qed.
+
+  Require Import FunctionalExtensionality.
+
+  Theorem unit_bind : forall A (ma : tm A),
+    bind (@ret A) ma = ma
+  .
+  intros.
+  unfold ret, bind.
+  rewrite <- M.unit_bind.
+  f_equal.
+  extensionality ea.
+  destruct ea; now reflexivity.
+  Qed.
+
+  Theorem bind_bind : forall A B C
+      (ma : tm A) (f : A -> tm B) (g : B -> tm C),
+    bind g (bind f ma)
+    =
+    bind (fun x => bind g (f x)) ma
+  .
+  intros.
+  unfold bind.
+  rewrite M.bind_bind.
+  f_equal.
+  extensionality x.
+  destruct x.
+  (* 1 *)
+    reflexivity.
+  (* 2 *)
+    rewrite M.bind_unit.
+    reflexivity.
+  Qed.
+
+  End Mo.
+
+  Include Mo.
+
+  Module Bc := BindCompose(M).
+
+  Theorem lift_bind : forall A B
+      (ma : M.m A) (f : A -> M.m B),
+    lift (M.bind f ma) = bind (fun x => lift (f x)) (lift ma)
+  .
+  intros.
+  unfold bind, lift.
+  rewrite Bc.bind_compose.
+  rewrite <- M.bind_bind.
+  reflexivity.
+  Qed.
+
+
+  Definition m := tm.
+
+(*
+  Implicit Arguments lift [A].
+  Implicit Arguments ret [A].
+  Implicit Arguments bind [A B].
+*)
+
+  (* specific functions *)
+
+  Definition throw {A} (e : Et.E) : m A := M.ret (@Eerr Et.E A e).
+
+  Definition catch {A} (ma : m A)
+    (handler : Et.E -> m A)
+   :
+    m A
+   :=
+    M.bind
+      (fun ea =>
+         match ea with
+         | Eok _ => M.ret ea
+         | Eerr e => handler e
+         end
+      )
+      ma
+  .
+
+  Definition run_error {A} (tma : tm A) : M.m (err Et.E A) :=
+    tma
+  .
+
+  Module Ops := MTinfix(Mo).
+
+  Module Inh := Minher(Mo).
+  Include Inh.
+
+End ErrorT.
+
+
+Module MapError
+ (M : MONAD)
+ (T1 : ERR_TYPE) (ET1 : ERROR_T (T1) (M))
+ (T2 : ERR_TYPE) (ET2 : ERROR_T (T2) (M))
+.
+  Definition map_error {A}
+   (mapper : T1.E -> T2.E)
+   (m1 : ET1.m A) : (ET2.m A)
+   :=
+    M.bind
+      (fun ea =>
+         match ea with
+         | Eok a => M.ret (Eok a)
+         | Eerr e => M.ret (Eerr (mapper e))
+         end
+      )
+      m1
+  .
+End MapError.
+
+
+(*
+Module Example_Et <: ERR_TYPE.
+  Definition E := nat.
+End Example_Et.
+
+Module Error_nat := ErrorT (Example_Et) (Identity).
+Import Error_nat.
+
+Eval compute in (ret tt).
+
+Eval compute in
+  (catch
+     (bind
+        (fun t =>
+           catch
+             (bind
+                (fun a =>
+                   bind
+                     (fun b =>
+                        if true (* false *)
+                        then
+                          throw (a + b)
+                        else
+                          ret (a + b)
+                     )
+                     (ret 3)
+                )
+                (ret 2)
+             )
+             (fun e => throw (e + 10))
+        )
+        (ret true)
+     )
+     (fun e => throw (e + 100))
+  )
+.
+
+
+Import Error_nat.Ops.
+
+Eval compute in
+  (catch
+     (ret true >>= fun t =>
+      catch
+        (ret 2 >>= fun a =>
+         ret 3 >>= fun b =>
+         if (* true *) false
+         then
+           throw (a + b)
+         else
+           ret (a + b)
+        )
+        (fun e => throw (e + 10))
+     )
+     (fun e => throw (e + 100))
+  )
+.
+
+
+Eval compute in
+  (catch
+     (t <- ret true;;
+      catch
+        (a <- ret 2;;
+         b <- ret 3;;
+         c <- lift $ Identity.ret 4;;
+         if (* true *) false
+         then
+           throw $ a + b + c
+         else
+           ret $ a + b + c
+        )
+        (fun e => throw (e + 10))
+     )
+     (fun e => throw (e + 100))
+  )
+.
+*)
+
+Module Error (Et : ERR_TYPE)
+<: MONAD
+<: ERROR_T (Et) (Identity).
+
+  Module Impl := ErrorT (Et) (Identity).
+
+  Include Impl.
+
+End Error.
+Module Type MONADT_RAW.
+  Parameter tm : Type -> Type.
+  Parameter ret : forall A (a : A), tm A.
+  Parameter bind : forall A B, (A -> tm B) -> tm A -> tm B.
+  Implicit Arguments ret [[A]].
+  Implicit Arguments bind [A B].
+
+  Parameter bind_unit : forall A B (a : A) (f : A -> tm B),
+    bind f (ret a) = f a
+  .
+
+  Parameter unit_bind : forall A (ma : tm A),
+    bind (@ret A) ma = ma
+  .
+
+  Parameter bind_bind : forall A B C
+      (ma : tm A) (f : A -> tm B) (g : B -> tm C),
+    bind g (bind f ma)
+    =
+    bind (fun x => bind g (f x)) ma
+  .
+
+End MONADT_RAW.
+
+Module Type For_infix.
+  Parameter m : Type -> Type.
+  Parameter ret : forall A (a : A), m A.
+  Parameter bind : forall A B, (A -> m B) -> m A -> m B.
+  Implicit Arguments ret [[A]].
+  Implicit Arguments bind [A B].
+End For_infix.
+
+
+Module Minfix (M : For_infix).
+
+  Notation "m >>= f" := (@M.bind _ _ f m)
+    (at level 57, right associativity)
+  .
+
+  Notation "x <- y ;; z" :=
+    (@M.bind _ _ (fun x : _ => z) y)
+    ( (*only parsing,*)
+      at level 66
+    , right associativity
+    , y at next level
+    )
+  .
+
+  Notation "y ;; z" := (@M.bind _ _ (fun (_u : unit) => z) y)
+    ( (*only parsing,*)
+       at level 66
+    , right associativity
+    )
+  .
+
+  Notation "f $ x" := (f x)
+    (at level 63, right associativity, only parsing)
+  .
+
+End Minfix.
+
+
+Module MTinfix (M : MONADT_RAW).
+  Module F <: For_infix.
+    Include M.
+    Definition m := tm.
+  End F.
+  Module O := Minfix(F).
+  Include O.
+End MTinfix.
+
+
+(*
+Module Type MONAD_LAWS (M : MONADT_RAW).
+
+  Import M.
+
+  Parameter bind_unit : forall A B (a : A) (f : A -> tm B),
+    bind f (ret a) = f a
+  .
+
+  Parameter unit_bind : forall A (ma : tm A),
+    bind (@ret A) ma = ma
+  .
+
+  Parameter bind_bind : forall A B C
+      (ma : tm A) (f : A -> tm B) (g : B -> tm C),
+    bind g (bind f ma)
+    =
+    bind (fun x => bind g (f x)) ma
+  .
+
+End MONAD_LAWS.
+*)
+
+Module Type MONAD.
+
+  Include MONADT_RAW.
+
+  Definition m := tm.
+
+End MONAD.
+
+
+Module Identity <: MONAD.
+
+  Module Raw <: MONADT_RAW.
+
+    Definition m (A : Type) := A.
+    Definition tm := m. (* for MTinfix *)
+
+    Definition ret {A} (x : A) := x.
+
+    Definition bind {A B} (f : A -> m B) (ma : m A) : m B :=
+      f ma.
+
+  Theorem bind_unit : forall A B (a : A) (f : A -> m B),
+    bind f (ret a) = f a
+  .
+  auto.
+  Qed.
+
+  Theorem unit_bind : forall A (ma : m A),
+    bind (@ret A) ma = ma
+  .
+  auto.
+  Qed.
+
+  Theorem bind_bind : forall A B C
+      (ma : m A) (f : A -> m B) (g : B -> m C),
+    bind g (bind f ma)
+    =
+    bind (fun x => bind g (f x)) ma
+  .
+  auto.
+  Qed.
+
+  End Raw.
+
+  Include Raw.
+
+  Module Ops := MTinfix(Raw).
+
+(*
+    Notation "m >>= f" := (@bind _ _ f m)
+      (at level 57, right associativity)
+    .
+*)
+
+End Identity.
+
+
+Module Type MONADT (M : MONAD) <: MONAD.
+
+  Module Mo.
+  Parameter tm : Type -> Type.
+
+  Parameter ret : forall A (x : A), tm A.
+  Global Implicit Arguments ret [[A]].
+
+  Parameter bind : forall A B
+    (f : A -> tm B)
+    (tma : tm A),
+    tm B
+  .
+  Global Implicit Arguments bind [A B].
+
+  Parameter bind_unit : forall A B (a : A) (f : A -> tm B),
+    bind f (ret a) = f a
+  .
+
+  Parameter unit_bind : forall A (ma : tm A),
+    bind (@ret A) ma = ma
+  .
+
+  Parameter bind_bind : forall A B C
+      (ma : tm A) (f : A -> tm B) (g : B -> tm C),
+    bind g (bind f ma)
+    =
+    bind (fun x => bind g (f x)) ma
+  .
+
+  End Mo.
+
+  Import Mo.
+
+  Parameter lift : forall A, M.m A -> tm A.
+
+  Global Implicit Arguments lift [[A]].
+
+  Parameter lift_ret : forall A (x : A),
+    lift (M.ret x) = (ret x)
+  .
+
+  Parameter lift_bind : forall A B
+      (ma : M.m A) (f : A -> M.m B),
+    lift (M.bind f ma) = bind (fun x => lift (f x)) (lift ma)
+  .
+
+  (* as a monad: *)
+
+  Definition m := tm.
+
+  Include Mo.
+
+  Module Ops := MTinfix(Mo).
+
+  Parameter fmap : forall {A B} (f : A -> B) (ma : m A), m B.
+
+End MONADT.
+
+
+Module Minher (M : MONADT_RAW).
+
+  Import M.
+
+  Definition fmap {A B} (f : A -> B) (ma : tm A) : tm B :=
+    M.bind
+      (fun a => M.ret (f a))
+      ma
+  .
+
+End Minher.
+
+
+Module BindCompose (M : MONADT_RAW).
+
+  Require Import FunctionalExtensionality.
+
+  Theorem bind_compose : forall A B C m
+    (f : A -> B)
+    (g : B -> M.tm C),
+    M.bind
+      (fun b => g b)
+      (M.bind
+         (fun a => M.ret (f a))
+         m
+      )
+    =
+    M.bind
+      (fun a => g (f a)
+      )
+      m
+  .
+  intros.
+  rewrite M.bind_bind.
+  f_equal.
+  extensionality x.
+  rewrite M.bind_unit.
+  reflexivity.
+  Qed.
+
+End BindCompose.
+(*
+Usage:
+
+    Require Import StateMonad.
+
+    Module MyState <: ST_TYPE.
+      Definition S := <type_of_my_state>.
+    End MyState.
+
+1. Transformer:
+
+    Module State_my := StateT (MyState) (InnerMonad).
+    Import State_my.  (* for functions *)
+    Import Ops.  (* for notations *)
+
+2. Monad:
+
+    Module State_my := State (MyState).
+    Import State_my.  (* for functions *)
+    Import Ops.  (* for notations *)
+
+
+Specific functions:
+
+  - get : m S
+  - put : S -> m unit
+  - run_state : S (*initial state*) -> m A -> InnerMonad.m (A * S)
+*)
+
+
+Require Import Monads.
+
+Module Type ST_TYPE.
+  Parameter S : Type.
+End ST_TYPE.
+
+Module Type MONADT_S (S : ST_TYPE) (M : MONAD).
+  Include MONADT (M).
+End MONADT_S.
+
+Module StateT (Stt : ST_TYPE) (M : MONAD) <: MONADT_S (Stt) (M).
+
+  Module Mo.
+
+  Definition tm A := Stt.S -> M.m (A * Stt.S).
+
+  Definition ret A (x : A) : tm A := fun s => M.ret (x, s).
+  Global Implicit Arguments ret [[A]].
+
+  Definition bind A B
+    (f : A -> tm B)
+    (tma : tm A)
+   :
+    tm B
+   :=
+    fun s1 =>
+      let mas1 := tma s1 in
+      M.bind
+        (fun a_s2 =>
+           let (a, s2) := (a_s2 : (A * Stt.S)) in
+           (f a) s2
+        )
+        mas1
+  .
+  Global Implicit Arguments bind [A B].
+
+  Require Import FunctionalExtensionality.
+
+  Theorem bind_unit : forall A B (a : A) (f : A -> tm B),
+    bind f (ret a) = f a
+  .
+  intros.
+  unfold bind, ret.
+  extensionality s.
+  rewrite M.bind_unit.
+  reflexivity.
+  Qed.
+
+  Theorem unit_bind : forall A (ma : tm A),
+    bind (@ret A) ma = ma
+  .
+  intros.
+  unfold bind, ret.
+  extensionality s.
+  rewrite <- M.unit_bind.
+  f_equal.
+  extensionality a_s2.
+  destruct a_s2.
+  reflexivity.
+  Qed.
+
+  Theorem bind_bind : forall A B C
+      (ma : tm A) (f : A -> tm B) (g : B -> tm C),
+    bind g (bind f ma)
+    =
+    bind (fun x => bind g (f x)) ma
+  .
+  intros.
+  unfold bind.
+  extensionality s1.
+  rewrite M.bind_bind.
+  f_equal.
+  extensionality a_s2.
+  destruct a_s2 as [a s2].
+  reflexivity.
+  Qed.
+
+  End Mo.
+
+  Include Mo.
+
+  Definition lift A (ma : M.m A) : tm A :=
+    fun s =>
+      M.bind
+        (fun a =>
+           M.ret (a, s)
+        )
+        ma
+  .
+  Global Implicit Arguments lift [[A]].
+
+
+  Require Import FunctionalExtensionality.
+
+  Theorem lift_ret : forall A (x : A),
+    lift (M.ret x) = (ret x)
+  .
+  intros.
+  unfold lift, ret.
+  extensionality s.
+  rewrite M.bind_unit.
+  reflexivity.
+  Qed.
+
+
+  Module Bc := BindCompose(M).
+
+  Theorem lift_bind : forall A B
+      (ma : M.m A) (f : A -> M.m B),
+    lift (M.bind f ma) = bind (fun x => lift (f x)) (lift ma)
+  .
+  intros.
+  unfold lift, ret, bind.
+  extensionality s.
+  rewrite Bc.bind_compose.
+  rewrite <- M.bind_bind.
+  reflexivity.
+  Qed.
+
+  (* as a monad: *)
+
+  Definition m := tm.
+
+  Module Ops := MTinfix(Mo).
+
+
+  (* specific functions: *)
+
+  Definition put (s : Stt.S) : m unit :=
+    fun old_s =>
+      M.ret (tt, s)
+  .
+
+  Definition get : m Stt.S :=
+    fun s =>
+      M.ret (s, s)
+  .
+
+  Definition run_state {A} (s : Stt.S) (tma : tm A) : M.m (A * Stt.S) :=
+    tma s
+  .
+
+
+  Module Inh := Minher(Mo).
+  Include Inh.
+
+
+End StateT.
+
+
+Module State (St : ST_TYPE) <: MONAD.
+
+  Module Impl := StateT (St) (Identity).
+
+  Include Impl.
+
+End State.
+
+
+
+(*
+Module MyState <: ST_TYPE.
+  Definition S := nat.
+End MyState.
+
+Module State_nat := StateT (MyState) (Identity).
+Import State_nat.
+Import Ops.
+
+Eval compute in
+  ( run_state
+      5
+      ( a <- ret 3;;
+        b <- get;;
+        put $ a + b;;
+        c <- get;;
+        ret $ c + 4
+      )
+  ).
+*)