(* Source: ocaml-toys / brainfuck / brainfuck.ml (full commit) *)
(** The abstract syntax tree (AST) of a brainfuck program: it is built by the
    parser, optionally rewritten by the optimizer, then executed. *)
module ParseTree = struct
  type operation =
    | Move of int            (** add n to the cell pointer *)
    | Add of int             (** add n to the current cell's value *)
    | Output                 (** write the current cell as an ascii char *)
    | Input                  (** read a char and write it in the current cell *)
    | Loop of operation list (** loop while the current cell is <> 0 *)
    (* the constructors below are only introduced by optimizations *)
    | Reset                  (** reset the current cell to 0 *)
    | AddMultToCell of int * int (** [(n, i)]: add (current cell) * n to the cell at offset i, then reset the current cell *)
    | AddMultToCell2 of int * int * int (** [(n, i, j)]: add (current cell) * n to both cells at offsets i and j, then reset the current cell *)
    | CopyMultToCell of int * int (** [(n, i)]: add (current cell) * n to the cell at offset i, keeping the current cell *)
    | AddTo of int * int (** [(n, i)]: add n to the cell at offset i, without moving the pointer *)

  (** Render one operation as text. A loop is printed as a ["Loop <<"] line,
      its body (one operation per line, each nesting level prefixed with an
      extra ["| "]) and a closing ["Loop >>"] line. *)
  let string_of_op op =
    let rec render prefix node =
      (* format a single-line operation, prepending the current prefix *)
      let line fmt = Printf.ksprintf (fun s -> prefix ^ s) fmt in
      match node with
      | Move i -> line "Move(%d)" i
      | Add i -> line "Add(%d)" i
      | Output -> line "Output"
      | Input -> line "Input"
      | Reset -> line "Reset"
      | AddMultToCell (n, i) -> line "AddMultToCell(%d, %d)" n i
      | AddMultToCell2 (n, i, j) -> line "AddMultToCell2(%d, %d, %d)" n i j
      | CopyMultToCell (n, i) -> line "CopyMultToCell(%d, %d)" n i
      | AddTo (n, i) -> line "AddTo(%d, %d)" n i
      | Loop body ->
        let nested = List.map (render ("| " ^ prefix)) body in
        (* the body is joined first so an empty loop still yields an
           (empty) middle line *)
        String.concat "\n"
          [ prefix ^ "Loop <<"; String.concat "\n" nested; prefix ^ "Loop >>" ]
    in
    render "" op

  (** Print every top-level operation of [ast] on stdout, each followed by a
      newline (a loop spans several lines). *)
  let dump ast =
    ast |> List.iter (fun node -> string_of_op node |> print_endline)
end

(** Parse the source text and build the AST. *)
module Parser = struct
  open ParseTree

  (* lexical analysis: build a token list from chars *)

  type token =
    | IncrPtr | DecrPtr
    | IncrData | DecrData
    | Write | Read
    | Open | Close
    | Comment of char

  (** Map a source character to its token. Every character that is not one
      of the eight brainfuck commands becomes a [Comment]. *)
  let token_of_char = function
    | '>' -> IncrPtr
    | '<' -> DecrPtr
    | '+' -> IncrData
    | '-' -> DecrData
    | '.' -> Write
    | ',' -> Read
    | '[' -> Open
    | ']' -> Close
    | c -> Comment c

  (** Consume the whole char stream eagerly and return its tokens in source
      order. *)
  let tokenize charstream =
    let tokens = ref [] in
    Stream.iter (fun c -> tokens := token_of_char c :: !tokens) charstream;
    List.rev !tokens

  (* syntactic analysis: build the AST from tokens *)

  (** [build_loop tokens] expects [tokens] to start with the [Open] of a loop.
      It returns the tokens of the loop body (without the enclosing
      brackets) and the tokens following the matching [Close].
      Kept for compatibility: [build_ast] no longer needs it.
      @raise Failure if the loop is never closed or [tokens] is empty. *)
  let build_loop tokens =
    let rec loop acc opened = function
      | Open :: rest -> loop (Open :: acc) (succ opened) rest
      | Close :: rest ->
          if opened = 0 then (List.rev acc), rest
          else loop (Close :: acc) (pred opened) rest
      | other :: rest -> loop (other :: acc) opened rest
      | [] -> failwith "malformed Loop"
    in
    loop [] 0 (List.tl tokens)

  (** [build_ast tokens] builds the AST in a single left-to-right pass over
      [tokens]; [Comment] tokens are dropped.
      @raise Failure ["malformed Loop"] if a '[' has no matching ']'.
      @raise Failure ["unmatched ']'"] if a ']' has no matching '['. *)
  let build_ast tokens =
    (* [block acc tokens] parses operations until the end of input or a
       [Close]. It returns the operations in order, plus [Some after] (the
       tokens following that [Close]) when it stopped on one, [None] at the
       end of input. Recursion depth is the loop nesting depth only. *)
    let rec block acc = function
      | [] -> List.rev acc, None
      | Close :: rest -> List.rev acc, Some rest
      | Open :: rest ->
          (match block [] rest with
           | body, Some after -> block (Loop body :: acc) after
           | _, None -> failwith "malformed Loop")
      | IncrPtr :: rest -> block (Move 1 :: acc) rest
      | DecrPtr :: rest -> block (Move (-1) :: acc) rest
      | IncrData :: rest -> block (Add 1 :: acc) rest
      | DecrData :: rest -> block (Add (-1) :: acc) rest
      | Write :: rest -> block (Output :: acc) rest
      | Read :: rest -> block (Input :: acc) rest
      | Comment _ :: rest -> block acc rest
    in
    match block [] tokens with
    | ast, None -> ast
    (* stopping on a Close at top level means it closes nothing *)
    | _, Some _ -> failwith "unmatched ']'"

  (** Build the AST from a char stream.
      @raise Failure on unbalanced brackets (see [build_ast]). *)
  let parse stream =
    let tokens = tokenize stream in
    build_ast tokens
end

(** AST-to-AST rewrites that replace common brainfuck idioms with cheaper
    operations. Every rewrite must leave the program's observable behaviour
    unchanged. *)
module Optimizer = struct
  open ParseTree

  (** Merge adjacent [Move]s and adjacent [Add]s into one operation, dropping
      a merged pair whose sum is 0. Recurses into loop bodies. After this
      pass no [Move 0] / [Add 0] produced by merging remains. *)
  let rec group = function
    | Move a :: Move b :: rest ->
      let lst = if a + b <> 0 then Move (a + b) :: rest else rest in
      group lst
    | Add a :: Add b :: rest ->
      let lst = if a + b <> 0 then Add (a + b) :: rest else rest in
      group lst
    | Loop a :: rest -> Loop (group a) :: group rest
    | other :: rest -> other :: group rest
    | [] -> []

  (** Replace loops whose body is a known idiom with a single operation.
      Each idiom decrements the current cell by exactly one per iteration,
      either as its last or as its first operation (both orders are
      equivalent). The [<> 0] guards ensure no target offset is the current
      cell itself, which these replacements cannot express. *)
  let rec unroll ast =
    let replace = function
      | [Add (-1)] -> Reset
      (* [>+<-] and [->+<] *)
      | [Move a; Add n; Move b; Add (-1)]
      | [Add (-1); Move a; Add n; Move b]
        when a = -b && a <> 0 -> AddMultToCell (n, a)
      (* [>+>+<<-] and [->+>+<<] *)
      | [Move a; Add n1; Move b; Add n2; Move c; Add (-1)]
      | [Add (-1); Move a; Add n1; Move b; Add n2; Move c]
        when a + b = -c && n1 = n2 && a <> 0 && a + b <> 0 ->
        AddMultToCell2 (n1, a, a + b)
      | other -> Loop (unroll other)
    in
    match ast with
    | Loop ops :: rest -> replace ops :: unroll rest
    | other :: rest -> other :: unroll rest
    | [] -> []

  (** Replace "move to offset i, add n, move back" with a single distant
      add, recursing into loop bodies. *)
  let in_place_adds ast =
    let rec loop = function
      | Move i :: Add n :: Move j :: rest
        when i = -j -> AddTo (n, i) :: loop rest
      | Loop ops :: rest -> Loop (loop ops) :: loop rest
      | other :: rest -> other :: loop rest
      | [] -> []
    in
    loop ast

  (** Handle the "copy through a temporary cell" idiom
      [[>+>+<<-]>>[<<+>>-]]. After [unroll] it reads
      [AddMultToCell2 (1, i, j); Move j; AddMultToCell (1, -j)].
      The first operation empties the source cell only for the last one to
      refill it, so it is replaced with [CopyMultToCell (1, i)], which leaves
      the source intact. The final [AddMultToCell] is kept: it still has to
      add the temporary cell's previous contents to the source and clear the
      temporary cell, since that cell is not known to start at 0.
      Only factor 1 is handled. With a factor n the source ends up as
      (tmp + v*n) * n, which a copy cannot express. *)
  let replace_moves_with_copy ast =
    let rec loop = function
      | AddMultToCell2 (1, i, j) :: Move k :: AddMultToCell (1, l) :: rest
        when j = k && j = -l ->
        CopyMultToCell (1, i) :: Move j :: AddMultToCell (1, l) :: loop rest
      | Loop ops :: rest -> Loop (loop ops) :: loop rest
      | other :: rest -> other :: loop rest
      | [] -> []
    in
    loop ast

  (** Shorten loops of the form "reset the cell at offset a, come back,
      decrement": one iteration already leaves both cells at 0. *)
  let rec shortcut_loops ast =
    let replace = function
      (* we need to keep the loop in case the current cell is already 0 ! *)
      | [Move a; Reset; Move b; Add (-1)]
        when a = -b -> Loop [Move a; Reset; Move b; Reset]
      | other -> Loop (shortcut_loops other)
    in
    match ast with
    | Loop ops :: rest -> replace ops :: shortcut_loops rest
    | other :: rest -> other :: shortcut_loops rest
    | [] -> []

  (** Function composition: [(f1 << f2) x = f1 (f2 x)]. *)
  let (<<) f1 f2 = fun x -> f1 (f2 x)

  (** Full pipeline. The first pass runs right to left: group, unroll,
      in-place adds, then copy detection. The second pass shortcuts loops
      and groups again. *)
  let optimize =
    let first_pass = replace_moves_with_copy << in_place_adds << unroll << group
    and second_pass = group << shortcut_loops
    in second_pass << first_pass

end

(** Runs the AST on a fixed-size tape of [int] cells. *)
module Interpreter = struct
  open ParseTree

  (** The tape: 30 000 cells, all initially 0. Cells are plain OCaml ints and
      are not wrapped on overflow/underflow. *)
  let memory = Array.make 30_000 0

  (** Index of the current cell in [memory]. *)
  let pointer = ref 0

  (** Execute [ast] against the global [memory] and [pointer].
      @raise Invalid_argument if the pointer (or an offset from it) leaves
      the tape. *)
  let rec exec ast =
    let exec_node = function
      | Move i -> pointer := !pointer + i
      | Add i -> memory.(!pointer) <- memory.(!pointer) + i
      | Output ->
          (* Cells are unbounded ints. Emit only the low byte, as an 8-bit
             implementation would, rather than letting char_of_int raise on
             values outside 0..255. Flush so output appears before any later
             read from stdin. *)
          print_char (Char.chr (memory.(!pointer) land 0xff));
          flush stdout
      | Input ->
          (* at end of input store 0, so idioms like ",[.,]" terminate
             instead of crashing with End_of_file *)
          memory.(!pointer) <-
            (match input_char stdin with
             | c -> Char.code c
             | exception End_of_file -> 0)
      | Loop nodes ->
          while memory.(!pointer) <> 0 do
            exec nodes
          done
      (* optimizations *)
      | Reset -> memory.(!pointer) <- 0
      | AddMultToCell (n, i) ->
        memory.(!pointer + i) <- memory.(!pointer + i) + (memory.(!pointer) * n);
        memory.(!pointer) <- 0
      | AddMultToCell2 (n, i, j) ->
        memory.(!pointer + i) <- memory.(!pointer + i) + (memory.(!pointer) * n);
        memory.(!pointer + j) <- memory.(!pointer + j) + (memory.(!pointer) * n);
        memory.(!pointer) <- 0
      | CopyMultToCell (n, i) ->
        memory.(!pointer + i) <- memory.(!pointer + i) + (memory.(!pointer) * n)
      | AddTo (n, i) -> memory.(!pointer + i) <- memory.(!pointer + i) + n
    in
    List.iter exec_node ast
end

(** [brainfuck optimize dump filename] parses the program in [filename],
    optimizes it when [!optimize] is set, then prints the AST when [!dump]
    is set or executes it otherwise. The refs are read when this function
    runs. *)
let brainfuck optimize dump filename =
  let ic = open_in filename in
  let ast =
    (* Parser.parse tokenizes the whole stream eagerly, so the channel can be
       closed right after parsing; close it on a parse failure too. *)
    match Parser.parse (Stream.of_channel ic) with
    | ast -> close_in ic; ast
    | exception e -> close_in_noerr ic; raise e
  in
  let ast' = if !optimize then Optimizer.optimize ast else ast in
  if !dump then ParseTree.dump ast' else Interpreter.exec ast'

(* Command-line entry point: [brainfuck [-optimize] [-dump] FILE...]. *)
let () =
  let dump = ref false in
  let optimize = ref false in
  let files = ref [] in
  let args = [
    ("-optimize", Arg.Set optimize, "optimize ast");
    ("-dump", Arg.Set dump, "dump ast, don't execute");
    ] in
  let usage =
    Printf.sprintf "usage: %s [-optimize] [-dump] FILE..." Sys.argv.(0) in
  (* Arg.parse calls the anonymous-argument handler as soon as it meets a
     file name. Running the program from there would ignore any option given
     after the file, so collect the files and run them once every option has
     been parsed. *)
  Arg.parse args (fun file -> files := file :: !files) usage;
  match List.rev !files with
  | [] -> Arg.usage args usage
  | files -> List.iter (brainfuck optimize dump) files