(* Source: ocaml-toys / maze / astar.ml (full commit) *)

(** A* path finding over a maze.

    Per-cell search state is kept in a [node array array] indexed as
    [matrix.(fst node).(snd node)] for a node given as an [int * int] pair. *)
module Pathfinding = struct
  (** Search bookkeeping for a reached cell: [steps] is the number of moves
      from the start, [estimation] is [steps] plus the heuristic for the
      cell, and [camefrom] is the predecessor on the best known path (the
      start cell points to itself). *)
  type cost = { steps: int; estimation: int; camefrom: int * int }

  (** Search state of one cell. [Exploring] cells form the open set and
      [Explored] cells the closed set. Nothing in this module sets
      [Closed]; cells in that state are treated as non-explorable. *)
  type node =
    | Unknown
    | Closed
    | Exploring of cost
    | Explored of cost

  (** Raised by [find_path] when the goal cannot be reached from the start. *)
  exception No_path_found

  (** Inputs of [MakeExploration]: the mutable state matrix and the
      heuristic (estimated remaining cost from a node to the goal). *)
  module type ExplorationType = sig
    val matrix : node array array
    val heuristic : int * int -> int
  end

  (** Operations reading and updating the search state in [E.matrix]. *)
  module MakeExploration = functor(E: ExplorationType) ->
    struct
      let matrix = E.matrix
      let heuristic = E.heuristic

      (** [is_explorable maze node] is [true] when [node] is passable in
          [maze] and its cell is still [Unknown] or [Exploring]. *)
      let is_explorable maze node =
        Maze.is_passable maze node
        && (match matrix.(fst node).(snd node) with
            | Unknown | Exploring _ -> true
            | Closed | Explored _ -> false)

      (** [is_exploring node] is [true] when [node]'s cell is [Exploring],
          i.e. when [node] is in the open set. *)
      let is_exploring node =
        match matrix.(fst node).(snd node) with
        | Exploring _ -> true
        | Unknown | Closed | Explored _ -> false

      (** Number of moves recorded from the start to [node]; [max_int] for a
          [Closed] cell.
          @raise Invalid_argument if [node]'s cell is [Unknown]. *)
      let get_steps node =
        match matrix.(fst node).(snd node) with
        | Exploring cost | Explored cost -> cost.steps
        | Closed -> max_int
        | Unknown -> invalid_arg "unknown node"

      (** Marks [node] as [Exploring], reached in [steps] moves via
          [camefrom]. Overwrites whatever state the cell had before. *)
      let set_exploring node steps camefrom =
        matrix.(fst node).(snd node) <-
          Exploring { steps; estimation = steps + heuristic node; camefrom }

      (** Moves [node] from [Exploring] to [Explored], keeping its cost.
          @raise Invalid_argument if [node]'s cell is not [Exploring]. *)
      let set_explored node =
        match matrix.(fst node).(snd node) with
        | Exploring cost -> matrix.(fst node).(snd node) <- Explored cost
        | Unknown | Closed | Explored _ ->
          invalid_arg ("not in exploring state: " ^ string_of_int (fst node) ^ ", " ^ string_of_int (snd node))

      (** [reconstruct_path node path] follows [camefrom] links back from
          [node] until it reaches a node that is its own predecessor (or has
          no recorded predecessor), and returns the visited nodes, ordered
          from that first node to [node], prepended to [path]. *)
      let rec reconstruct_path node path =
        let camefrom =
          match matrix.(fst node).(snd node) with
          | Explored cost | Exploring cost -> cost.camefrom
          | Unknown | Closed -> node
        in
        if camefrom = node then node :: path
        else reconstruct_path camefrom (node :: path)

      (** Estimated total cost recorded in a cell; [max_int] if none. *)
      let estimate = function
        | Exploring cost | Explored cost -> cost.estimation
        | Unknown | Closed -> max_int

      (** Total order on nodes: by estimated total cost, with ties broken by
          structural comparison, so distinct nodes never compare equal.
          Scores are compared with [Stdlib.compare] rather than by
          subtraction, which can overflow for extreme scores. *)
      let compare a b =
        let structural = Stdlib.compare a b in
        if structural = 0 then 0
        else
          let score_a = estimate matrix.(fst a).(snd a) in
          let score_b = estimate matrix.(fst b).(snd b) in
          let by_score = Stdlib.compare (score_a : int) score_b in
          if by_score = 0 then structural else by_score

    end

  (** [find_path maze start goal] runs A* from [start] to [goal], expanding
      each node to the cells returned by [Maze.neighbor_nodes] at a cost of
      one step each. The result lists the nodes from [goal] back to [start],
      both included.
      @raise No_path_found if the open set runs empty before reaching
      [goal]. *)
  let find_path maze start goal =
    let module Exploration = MakeExploration(struct
        let matrix = Array.make_matrix (Maze.width maze) (Maze.height maze) Unknown
        let heuristic = Maze.distance goal
      end)
    in
    (* The open set. Invariant: a node is in [openset] iff its cell is
       [Exploring], so membership is read from the matrix instead of
       scanning the list. *)
    let openset = ref [start] in
    Exploration.set_exploring start 0 start;
    (* Removes and returns the open node with the lowest estimate.
       [Exploration.compare] never ties distinct nodes, so the minimum is
       unique and a linear scan selects the same node a full sort would. *)
    let pop_best () =
      match !openset with
      | [] -> raise No_path_found
      | first :: rest ->
        let best =
          List.fold_left
            (fun best node ->
               if Exploration.compare node best < 0 then node else best)
            first rest
        in
        openset := List.filter (fun node -> node <> best) !openset;
        best
    in
    (* Offers [node] a path through [current]: opens it if it was unknown,
       or records the shorter route if it is already open. *)
    let relax current node =
      if Exploration.is_explorable maze node then begin
        let tentative_steps = Exploration.get_steps current + 1 in
        if not (Exploration.is_exploring node) then begin
          Exploration.set_exploring node tentative_steps current;
          openset := node :: !openset
        end
        else if tentative_steps < Exploration.get_steps node then
          Exploration.set_exploring node tentative_steps current
      end
    in
    let rec loop () =
      let current = pop_best () in
      if current = goal then List.rev (Exploration.reconstruct_path goal [])
      else begin
        Exploration.set_explored current;
        List.iter (relax current) (Maze.neighbor_nodes maze current);
        loop ()
      end
    in
    loop ()
end

(*
let () =
  let maze, goal, ants = Maze.Parser.from_file "maze/maze.txt" in
  let start = List.hd ants in
  let before = Unix.gettimeofday () in
  let path = Pathfinding.find_path maze start goal in
  
  for i = 0 to 1_000 do
    ignore (Pathfinding.find_path maze start goal)
  done;
  
  let after = Unix.gettimeofday () in
  Maze.walk maze (List.rev path);
  Maze.draw maze;
  Printf.printf "%d steps\n" (List.length path);
  let time = (after -. before) *. 1000. in
  Printf.printf "%f ms\n" time
*)