ollio / oll.ml

(** Raised by [parse_line] when a line has no fields, or when a feature
    token does not split into exactly one [index:value] pair. *)
exception Bad_line;;

(** Declared for callers of this module; not raised by any of the code
    visible in this file. *)
exception Unknown;;

(** Mutable 2x2 counters of (first sign, second sign) pairs, ±1 each.
    The field name reads (first, second) in the argument order of
    [update_result r y_bar sign]: [pos] stands for 1, [neg] for -1. *)
type r = {
  mutable pos_pos : int;  (** first = 1,  second = 1 *)
  mutable pos_neg : int;  (** first = 1,  second = -1 *)
  mutable neg_pos : int;  (** first = -1, second = 1 *)
  mutable neg_neg : int;  (** first = -1, second = -1 *)
};;

(** [new_result ()] allocates a fresh counter record with every cell at 0.
    Each call returns a distinct record, so callers may mutate it freely. *)
let new_result () =
  let zero = 0 in
  { pos_pos = zero; pos_neg = zero; neg_pos = zero; neg_neg = zero };;

(** [update_result r y_bar sign] increments the cell of [r] selected by the
    pair ([y_bar], [sign]). The first half of the field name follows
    [y_bar] and the second half follows [sign] ([pos] = 1, [neg] = -1).
    @raise Invalid_argument if either value is neither 1 nor -1. *)
let update_result r y_bar sign =
  match (y_bar, sign) with
  | (1, 1) -> r.pos_pos <- r.pos_pos + 1
  | (1, -1) -> r.pos_neg <- r.pos_neg + 1
  | (-1, 1) -> r.neg_pos <- r.neg_pos + 1
  | (-1, -1) -> r.neg_neg <- r.neg_neg + 1
  | _ -> invalid_arg "only 1 or -1 is accepted as result";;

(** [print_result r] prints the accuracy and the confusion matrix of [r]
    to stdout.

    [r] is indexed (prediction, answer): the evaluation loop calls
    [update_result r y_bar sign] with the prediction [y_bar] first. So under
    the "(Answer, Predict)" header, the (p,n) cell (answer +1, predicted -1)
    is [r.neg_pos], and the (n,p) cell is [r.pos_neg].

    When [r] holds no samples, accuracy is printed as 0 rather than nan. *)
let print_result r =
  let hit = r.pos_pos + r.neg_neg in
  let wrong = r.pos_neg + r.neg_pos in
  let total = hit + wrong in
  (* Guard the 0/0 case, which would otherwise print "nan%". *)
  let acc =
    if total = 0 then 0.0
    else 100.0 *. float_of_int hit /. float_of_int total
  in
  Printf.printf "Accuracy %f%% (%d/%d)\n" acc hit total;
  Printf.printf "(Answer, Predict): (p,p):%d (p,n):%d (n,p):%d (n,n):%d\n"
    r.pos_pos r.neg_pos r.pos_neg r.neg_neg;;

(** [predict w x] classifies [x] with weight vector [w] using the score
    [Vector.prod w x] (presumably the inner product; confirm in [Vector]).
    Returns 1 when the score is strictly positive and -1 otherwise, so a
    score of exactly 0.0 (or nan) yields -1. *)
let predict w x =
  match Vector.prod w x with
  | score when score > 0.0 -> 1
  | _ -> -1;;

(** [pa x l] : vector -> float -> float is the basic PA step size: the loss
    [l] divided by [Vector.norm2 x]. [norm2] is assumed to be the squared
    Euclidean norm, as in the PA update rule; TODO confirm in [Vector].
    There is no guard against a zero denominator. *)
let pa x l =
  let denom = Vector.norm2 x in
  l /. denom;;

(** [pa1 c x l] is the PA-I step size: [pa x l] capped from above at [c].
    If [pa x l] is not below [c] (including when it is nan), [c] is
    returned. *)
let pa1 c x l =
  match pa x l with
  | tau when tau < c -> tau
  | _ -> c;;

(** [pa2 c x l] is the PA-II step size:
    [l / (Vector.norm2 x + 1 / (2c))]. *)
let pa2 c x l =
  let denom = Vector.norm2 x +. 1.0 /. (2.0 *. c) in
  l /. denom;;

(** [tau_getter name] : string -> (vector -> float -> float) selects the
    step-size function for the learning algorithm [name]:
    "PA" -> [pa], "PA1" -> [pa1 1.0], "PA2" -> [pa2 1.0].
    The aggressiveness parameter of PA1/PA2 is fixed at 1.0.
    @raise Invalid_argument for any other name. *)
let tau_getter name =
  match name with
  | "PA" -> pa
  | "PA1" -> pa1 1.0
  | "PA2" -> pa2 1.0
  | _ -> invalid_arg "PA/PA1/PA2 is only accepted as learning algorithm";;

(** [loss y_t w x] is the hinge loss [1 - y_t * Vector.prod w x], clamped
    from below at 0.0. [y_t] is the example's integer label (presumably
    ±1, as produced by [parse_line]). *)
let loss y_t w x =
  let margin = Vector.prod w x *. float_of_int y_t in
  match 1.0 -. margin with
  | v when v < 0.0 -> 0.0
  | v -> v;;

(* int_of_string rejected a leading '+' (e.g. "+1") when this was written,
   so a single leading '+' is stripped before converting. *)
(** [my_int_of_string s] is [int_of_string s], except that one leading
    ['+'] is removed first.
    @raise Failure if the remaining text is not a valid integer, including
    when [s] is empty (previously the empty string raised
    [Invalid_argument] from the out-of-bounds index). *)
let my_int_of_string s =
  let len = String.length s in
  (* Structural [=] on chars; the bounds check avoids indexing "". *)
  if len > 0 && s.[0] = '+' then
    int_of_string (String.sub s 1 (len - 1))
  else
    int_of_string s;;

(** [parse_line line] parses one example of the form
    ["<label> <index>:<value> <index>:<value> ..."] into
    [(label, [(index, value); ...])]. The feature list keeps the line's
    order. Label and indices may carry a leading ['+'], which
    [my_int_of_string] strips.

    Fields may be separated by any run of spaces or tabs, and a trailing
    carriage return (e.g. from CRLF files) is ignored.

    @raise Bad_line if the line has no fields, or a feature token does not
    split into exactly one [index:value] pair.
    @raise Failure if a label, index or value is not a valid number. *)
let parse_line =
  (* Compile the separators once instead of on every input line. *)
  let field_sep = Str.regexp "[ \t\r]+" in
  let pair_sep = Str.regexp ":" in
  let l2p token =
    match Str.split pair_sep token with
    | [ i; v ] -> (my_int_of_string i, float_of_string v)
    | _ -> raise Bad_line
  in
  fun line ->
    match Str.split field_sep line with
    | [] -> raise Bad_line
    | hd :: tl -> (my_int_of_string hd, List.map l2p tl);;

(** Result of one read from a channel: a line of text, or end of input. *)
type line_t = Line of string | EOF;;

(** [get_line input_in] reads the next line of [input_in] with
    [input_line], which drops the newline. It returns [EOF] instead of
    raising [End_of_file]; any other exception propagates. *)
let get_line input_in =
  let next = try Some (input_line input_in) with End_of_file -> None in
  match next with
  | Some text -> Line text
  | None -> EOF;;


(** [train t file] : string -> string -> vector learns a weight vector
    online from [file], one example per line in the format accepted by
    [parse_line].

    Starting from [Vector.newv ()], each example (sign, x) applies
    [w <- w + (tau * sign) * x], with [tau = (tau_getter t) x (loss sign w x)].
    The input channel is closed on normal completion and when an exception
    escapes.

    @raise Invalid_argument if [t] is not "PA", "PA1" or "PA2". This is
    checked before [file] is opened.
    @raise Bad_line or Failure on a malformed line (see [parse_line]).
    @raise Sys_error if [file] cannot be opened. *)
let train t file =
  (* Resolve the algorithm first so a bad name cannot leak an open channel. *)
  let f_tau = tau_getter t in
  let input_in = open_in file in
  let rec train_ w =
    match get_line input_in with
    | Line line ->
      let sign, x = parse_line line in
      let tau = f_tau x (loss sign w x) in
      train_ (Vector.add w (Vector.s_prod (tau *. float_of_int sign) x))
    | EOF -> w
  in
  (* Only the outermost call sits inside [try]; the recursion in [train_]
     remains a tail call. *)
  let w =
    try train_ (Vector.newv ())
    with e ->
      close_in_noerr input_in;
      raise e
  in
  close_in input_in;
  w;;

(** [test ww t file] : vector -> string -> string -> r evaluates the weights
    [ww] on the examples in [file] and returns the filled-in counter record.

    For every line, the prediction [y_bar = predict w x] and the label [sign]
    are recorded with [update_result r y_bar sign]. The weights keep being
    updated online with the same rule as [train]. Each prediction therefore
    uses the vector obtained from [ww] after all previous lines.

    The input channel is closed on normal completion and when an exception
    escapes.

    @raise Invalid_argument if [t] is not "PA", "PA1" or "PA2" (checked
    before [file] is opened), or if a label is neither 1 nor -1 (raised by
    [update_result]).
    @raise Bad_line or Failure on a malformed line (see [parse_line]).
    @raise Sys_error if [file] cannot be opened. *)
let test ww t file =
  (* Resolve the algorithm first so a bad name cannot leak an open channel. *)
  let f_tau = tau_getter t in
  let input_in = open_in file in
  let r = new_result () in
  let rec test_ w =
    match get_line input_in with
    | Line line ->
      let sign, x = parse_line line in
      let y_bar = predict w x in
      let tau = f_tau x (loss sign w x) in
      update_result r y_bar sign;
      test_ (Vector.add w (Vector.s_prod (tau *. float_of_int sign) x))
    | EOF -> ()
  in
  (try test_ ww
   with e ->
     close_in_noerr input_in;
     raise e);
  close_in input_in;
  r;;
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.