ocaml-minigames / mills / ai_alphabeta.ml

open Mill

(* A move labelling an edge of the game tree (see [node]).  [Init] labels a
   root that was not reached by a recorded move. *)
type move =
  | Init
  | Put of int (* index of the point where a stone is put *)
  | Move of int * int (* from, goal *)
  | Fly of int * int (* from, goal *)
  | Capture of int (* index of the opponent stone to capture *)

(* [string_of_move m] is a short human-readable description of [m], e.g.
   "put 3" or "move 1 2"; used when printing trees. *)
let string_of_move m =
  match m with
  | Init -> "init move"
  | Put i -> "put " ^ string_of_int i
  | Capture i -> "capture " ^ string_of_int i
  | Move (src, dst) -> Printf.sprintf "move %d %d" src dst
  | Fly (src, dst) -> Printf.sprintf "fly %d %d" src dst

(* A node of the lazily built game tree: the position, the move that led to
   it, and its children, each of which is only built when forced (see
   [build]).  A node without children is a leaf; [alphabeta] scores it with
   the heuristic. *)
type node = Node of state * move * node lazy_t list (* state, move, children *)

(* [string_of_node n] shows only the move that produced [n]; the state and
   the children are elided as "_". *)
let string_of_node (Node (_, mv, _)) =
  Printf.sprintf "Node(_, %s, _)" (string_of_move mv)

(* [can_capture state move] is [true] when [move] put a stone on a point
   (put, move or fly) that is part of a mill in [state].  [build] uses it to
   decide whether the next level of the tree consists of capture moves.
   Captures and the [Init] placeholder never allow a capture. *)
let can_capture state = function
  | Put point | Move (_, point) | Fly (_, point) -> in_mill state point
  | Capture _ | Init -> false

(* [build state last_move depth] lazily builds the game tree rooted at
   [state], which was reached by [last_move], down to [depth] plies.
   Nothing is computed until the returned thunk is forced.

   The children of the root are, by priority:
   - capture moves, when [last_move] closed a mill ([can_capture]).  This is
     tested before the depth limit, so a pending capture is expanded even at
     depth 0;
   - none (a leaf) once [depth <= 0];
   - put moves while [can_put state], else fly moves while
     [can_fly state], else ordinary moves.
   Every child is played on a [copy] of [state] followed by [end_of_turn].

   Warning: forcing a capture node calls [unroll_turn] on [state] itself,
   not on a copy, so the caller's state is mutated at that point. *)
let rec build state last_move depth =
  (* Node for [state] with one child per element of [candidates].  [play]
     applies a candidate to a fresh copy of [state] and returns the [move]
     to record in that child. *)
  let expand candidates play =
    let child c =
      let s = copy state in
      let m = play s c in
      end_of_turn s;
      build s m (pred depth)
    in
    Node (state, last_move, List.map child candidates)
  in
  let build_capture () =
    (* NOTE(review): presumably undoes the [end_of_turn] applied when the
       mill-closing move was played, so the capturing player is on turn. *)
    unroll_turn state;
    let opponent = get_color (succ (get_turn state)) in
    expand (capturables state opponent) (fun s i -> capture s i; Capture i)
  in
  let build_put () = expand (free_dots state) (fun s i -> put s i; Put i) in
  let build_fly () =
    expand (all_flies state) (fun s (f, g) -> fly s f g; Fly (f, g))
  in
  let build_move () =
    expand (all_moves state) (fun s (f, g) -> move s f g; Move (f, g))
  in
  (* [Lazy.from_fun] replaces the deprecated [Lazy.lazy_from_fun]. *)
  if can_capture state last_move then Lazy.from_fun build_capture
  else if depth <= 0 then lazy (Node (state, last_move, []))
  else if can_put state then Lazy.from_fun build_put
  else if can_fly state then Lazy.from_fun build_fly
  else Lazy.from_fun build_move

(* Debug statistic: number of [force] calls since the last [reset_force]. *)
let force_count = ref 0

(* [force thunk] is [Lazy.force thunk], counted in [force_count]. *)
let force thunk =
  force_count := !force_count + 1;
  Lazy.force thunk

(* Report the [force] counter on stderr, flushing immediately. *)
let print_force_stat () =
  let n = !force_count in
  Printf.eprintf "force called %d times\n%!" n

(* Set the [force] counter back to zero. *)
let reset_force () =
  force_count := 0

(* [force_tree root] forces every node of the tree rooted at [root],
   depth-first, visiting children in list order. *)
let force_tree root =
  let rec visit t =
    let (Node (_, _, kids)) = force t in
    List.iter visit kids
  in
  visit root

(* [print_tree root] forces the whole tree and prints it to stdout in
   pre-order: one move per line, prefixed by one '+' per level of depth
   (none for the root). *)
let print_tree root =
  let rec visit level t =
    let (Node (_, mv, kids)) = force t in
    Printf.printf "%s%s\n" (String.make level '+') (string_of_move mv);
    List.iter (visit (level + 1)) kids
  in
  visit 0 root

(* [count_positions node] forces the tree rooted at [node] and returns its
   number of leaves (childless nodes). *)
let count_positions node =
  let rec leaves t =
    match force t with
    | Node (_, _, []) -> 1
    | Node (_, _, kids) ->
        List.fold_left (fun total kid -> total + leaves kid) 0 kids
  in
  leaves node

(* [alphabeta heuristic node alpha beta] is the minimax value of [node]
   searched with alpha-beta pruning in the window [alpha, beta].

   - Leaves are scored with [heuristic] applied to their state.
   - Inner nodes minimise over their children when [get_turn] of their state
     is even, and maximise otherwise.
   - Once the running value leaves the window, the remaining children are
     skipped and the running value is returned as is.

   Children are forced one at a time, so pruned subtrees are never built. *)
let rec alphabeta heuristic node alpha beta =
  (* Minimising side: [bound] is the tightening upper limit passed down. *)
  let rec minimize best bound = function
    | [] -> best
    | child :: rest ->
        let best = min best (alphabeta heuristic child alpha bound) in
        if alpha >= best then best else minimize best (min bound best) rest
  in
  (* Maximising side: [bound] is the tightening lower limit passed down. *)
  let rec maximize best bound = function
    | [] -> best
    | child :: rest ->
        let best = max best (alphabeta heuristic child bound beta) in
        if best >= beta then best else maximize best (max bound best) rest
  in
  match force node with
  | Node (s, _, []) -> heuristic s
  | Node (s, _, children) ->
      if get_turn s mod 2 = 0 then minimize max_int beta children
      else maximize min_int alpha children

(* [score heuristic node] is the alpha-beta value of [node] computed with
   the widest possible window. *)
let score heuristic node =
  let lower = min_int and upper = max_int in
  alphabeta heuristic node lower upper

(* [estimate state] is the heuristic score of a position: Black's stone count
   minus White's.  Once no more stones can be put, a side with more than two
   stones facing a side with at most two scores [max_int] (Black ahead) or
   [min_int] (White ahead). *)
let estimate state =
  let black = get_count state Black in
  let white = get_count state White in
  let diff = black - white in
  if can_put state then diff
  else
    match black > 2, white > 2 with
    | true, false -> max_int
    | false, true -> min_int
    | true, true | false, false -> diff

let player =
  (* [find_best state lastmove depth] builds the game tree rooted at [state]
     with [build] and scores each child of the root with [score estimate].
     It returns the chosen child, forced: the lowest score wins when
     [get_turn state] is even, otherwise the highest.  Ties go to the
     earliest child.

     The choice is seeded with the first child rather than a sentinel score.
     A move is therefore returned even when every child has the extreme
     score [max_int]/[min_int] (e.g. all moves lose).
     If the root has no children, a childless node labelled [lastmove] is
     returned, and the callers' pattern matches raise [Failure]. *)
  let find_best state lastmove depth =
    let root = build state lastmove depth in
    reset_force ();
    let children = match force root with Node (_, _, c) -> c in
    let scores = List.map (score estimate) children in
    (* Read the turn only after forcing: a capture root calls [unroll_turn]
       on [state] when it is forced. *)
    let better = if get_turn state mod 2 = 0 then ( < ) else ( > ) in
    let keep_better ((_, best_score) as best) ((_, s) as candidate) =
      if better s best_score then candidate else best
    in
    let thunk =
      match List.combine children scores with
      | [] -> lazy (Node (state, lastmove, []))
      | first :: rest -> fst (List.fold_left keep_better first rest)
    in
    print_force_stat ();
    force thunk
  in

  object (self)
    (* Last move played by this AI; [capture] needs it to rebuild the
       capture node. *)
    val mutable last_move = Init

    (* Put phase, searched at a fixed depth of 5. *)
    method put state =
      let selected = find_best state Init 5 in
      match selected with
      | Node(_, Put i, _) -> last_move <- Put i ; i
      | _ -> failwith "no put found"
    method move state =
      let before = Unix.gettimeofday() in
      let depth = (* reduced when one player can jump *)
        if get_count state Black > 3 && get_count state White > 3 then 7
        else 4
      in
      let selected = find_best state Init depth in
      let after = Unix.gettimeofday() in
      Printf.printf "time: %f\n%!" (after -. before);
      match selected with
      | Node(_, Move (f, g), _) -> last_move <- Move (f, g) ; f, g
      | _ -> failwith "no move found"
    method fly state =
      let selected = find_best state Init 4 in
      match selected with
      | Node(_, Fly (f, g), _) -> last_move <- Fly (f, g) ; f, g
      | _ -> failwith "no fly found"
    method capture state =
      (* [build] calls [unroll_turn] on [state] when it expands the capture
         root, undoing this [end_of_turn]. *)
      end_of_turn state;
      let selected = find_best state last_move 4 in
      match selected with
      | Node(_, Capture i, _) -> last_move <- Capture i ; i
      | _ -> failwith "no capture found"
  end
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.