(* Source: ollio / oll.ml (full commit) *)
(* Raised by [parse_line] when a line has no tokens or a feature token is
   not exactly one "index:value" pair. *)
exception Bad_line

(* Raised by [update_result] for a prediction/label outside {1, -1} and by
   [tau_getter] for an unrecognised algorithm name. *)
exception Unknown;;

(* Confusion-matrix counters filled in by [update_result]. Field names read
   prediction-then-label: e.g. [pos_neg] counts examples predicted +1 whose
   true label is -1. *)
type r = {
  mutable pos_pos : int;  (* predicted +1, label +1: true positives *)
  mutable pos_neg : int;  (* predicted +1, label -1: false positives *)
  mutable neg_pos : int;  (* predicted -1, label +1: false negatives *)
  mutable neg_neg : int;  (* predicted -1, label -1: true negatives *)
};;

(* A fresh counter record with every cell of the confusion matrix at zero. *)
let new_result () =
  let zero = 0 in
  { neg_neg = zero; neg_pos = zero; pos_neg = zero; pos_pos = zero };;

(* [update_result r y_bar sign] increments the counter for the pair
   (prediction [y_bar], true label [sign]). Both must be 1 or -1; any other
   combination raises [Unknown] and leaves [r] untouched. *)
let update_result r y_bar sign =
  match (y_bar, sign) with
  | (1, 1) -> r.pos_pos <- r.pos_pos + 1
  | (1, -1) -> r.pos_neg <- r.pos_neg + 1
  | (-1, 1) -> r.neg_pos <- r.neg_pos + 1
  | (-1, -1) -> r.neg_neg <- r.neg_neg + 1
  | _ -> raise Unknown;;

(* Print the confusion matrix held in [r] to stdout.
   Counter fields are named prediction-then-label (see [update_result], which
   matches on [(y_bar, sign)]). A false positive (predicted +1, labelled -1)
   is therefore [pos_neg], and a false negative (predicted -1, labelled +1)
   is [neg_pos]. The original code printed these two swapped. *)
let print_result r =
  Printf.printf "TP: %d  TN: %d\n" r.pos_pos r.neg_neg;
  Printf.printf "FP: %d  FN: %d\n" r.pos_neg r.neg_pos;;

(* [judge w x] classifies [x] under weights [w]. It returns 1 when the score
   [Vector.prod w x] is strictly positive and -1 otherwise (a score of exactly
   0.0 counts as -1). *)
let judge w x =
  match Vector.prod w x with
  | score when score > 0.0 -> 1
  | _ -> -1;;

(* Step size of plain PA: tau = l / norm2 x.
   NOTE(review): assumes [Vector.norm2] is the squared L2 norm, as the PA
   update requires — confirm in the Vector module.
   vector -> float -> float *)
let pa x l =
  let sq_norm = Vector.norm2 x in
  (* An all-zero feature vector cannot move w, so take no step. Dividing by
     0.0 would give inf/nan, and scaling x's explicit zero entries by inf
     yields nan, which would then spread through w on later updates. *)
  if sq_norm = 0.0 then 0.0 else l /. sq_norm;;

(* PA-I step size: the plain [pa] step capped at the aggressiveness bound [c].
   vector -> float -> float *)
let pa1 c x l =
  let clip bound step = if step < bound then step else bound in
  clip c (pa x l);;

(* PA-II step size: tau = l / (norm2 x + 1 / (2c)). For c > 0 the extra
   1/(2c) term is positive and softens the step.
   vector -> float -> float *)
let pa2 c x l =
  let slack = 1.0 /. (2.0 *. c) in
  let denom = Vector.norm2 x +. slack in
  l /. denom;;

(* Map an algorithm name to its step-size function (vector -> float -> float).
   "PA" -> [pa]; "PA1" and "PA2" -> [pa1] / [pa2] with aggressiveness c = 1.0.
   Any other name raises [Unknown]. *)
let tau_getter name =
  match name with
  | "PA" -> pa
  | "PA1" -> pa1 1.0
  | "PA2" -> pa2 1.0
  | _ -> raise Unknown;;

(* [loss y_bar ans] = max 0 (1 - ans * y_bar), computed as a float.
   With y_bar, ans in {1, -1} this is 0.0 when they agree and 2.0 when they
   differ. *)
let loss y_bar ans =
  let raw = 1.0 -. float_of_int (ans * y_bar) in
  if raw >= 0.0 then raw else 0.0;;

(* Parse a decimal integer (label or feature index) that may carry an explicit
   leading '+', e.g. the label "+1".
   NOTE(review): older OCaml runtimes' [int_of_string] rejected a leading '+';
   this wrapper keeps the parser portable.
   Raises [Failure] on malformed input, including "" and "+". Previously ""
   raised Invalid_argument from [String.get]. *)
let my_int_of_string s =
  let n = String.length s in
  (* Compare characters structurally ([=]), not physically ([==]). *)
  if n > 0 && s.[0] = '+' then int_of_string (String.sub s 1 (n - 1))
  else int_of_string s;;

(* [parse_line line] parses one example written as
   "label idx:val idx:val ..." into [(label, [(idx, value); ...])].
   Tokens may be separated by any run of spaces or tabs. A trailing '\r'
   (CRLF files) is treated as a separator, so it is ignored rather than
   breaking [float_of_string].
   Raises [Bad_line] for a line with no tokens or a feature token that is not
   exactly one "index:value" pair, and [Failure] if a number does not
   parse. *)
let parse_line =
  (* Compile the separators once instead of once per input line. *)
  let field_sep = Str.regexp "[ \t\r]+" in
  let pair_sep = Str.regexp ":" in
  let l2p token =
    match Str.split pair_sep token with
    | [i; v] -> (my_int_of_string i, float_of_string v)
    | _ -> raise Bad_line
  in
  fun line ->
    match Str.split field_sep line with
    | [] -> raise Bad_line
    | hd :: tl ->
      let label = my_int_of_string hd in
      (label, List.map l2p tl);;

(* "PA"|"PA1"|"PA2" -> string -> vector *)
(* [train t file] makes one online pass of passive-aggressive variant [t] over
   the examples in [file] (one per line, see [parse_line]). It starts from
   [Vector.newv ()] and returns the learned weights.
   For each example (x, y) with y in {1, -1}:
     margin m = w . x,  hinge loss l = max 0 (1 - y * m),
     w <- w + tau * y * x,  where tau = [tau_getter t] x l.
   The update is signed by the label y (the original omitted it, pushing w
   the wrong way on negative examples). The loss uses the real margin rather
   than the predicted sign, as PA requires.
   Raises [Unknown] for an unknown [t] or a label other than 1/-1,
   [Bad_line]/[Failure] on malformed lines, and [Sys_error] if [file] cannot
   be opened. The input channel is closed on every exit path. *)
let train t file =
  (* Resolve the algorithm before opening the file so a bad name cannot leak
     the channel. *)
  let f_tau = tau_getter t in
  let input_in = open_in file in
  let r = new_result () in
  (* The recursive call stays outside any exception handler, so the loop is
     tail-recursive and runs in constant stack on large files. *)
  let rec train_ w =
    match (try Some (input_line input_in) with End_of_file -> None) with
    | None -> w
    | Some line ->
      let sign, x = parse_line line in
      let margin = Vector.prod w x in
      (* Same decision rule as [judge], reusing the margin computed above. *)
      let y_bar = if margin > 0.0 then 1 else -1 in
      let hinge =
        let v = 1.0 -. (float_of_int sign *. margin) in
        if v < 0.0 then 0.0 else v
      in
      let tau = f_tau x hinge in
      (* Also validates the label: raises [Unknown] unless sign is 1/-1. *)
      update_result r y_bar sign;
      train_ (Vector.add w (Vector.s_prod (tau *. float_of_int sign) x))
  in
  let w =
    try train_ (Vector.newv ())
    with e -> close_in_noerr input_in; raise e
  in
  close_in input_in;
  w;;

(* "PA"|"PA1"|"PA2" -> vector -> string -> r *)
(* [test ww t file] scores every example in [file] starting from weights [ww]
   and returns the confusion counts (type [r]). Each example is judged before
   the model learns from it. The model then keeps learning online with
   variant [t], so later lines are judged by an updated vector.
   The update is w <- w + tau * y * x with tau from the hinge loss
   max 0 (1 - y * (w . x)), i.e. signed by the label and driven by the
   real margin.
   Raises [Unknown] for an unknown [t] or a label other than 1/-1,
   [Bad_line]/[Failure] on malformed lines, and [Sys_error] if [file] cannot
   be opened. The input channel is closed on every exit path. *)
let test ww t file =
  (* Resolve the algorithm before opening the file so a bad name cannot leak
     the channel. *)
  let f_tau = tau_getter t in
  let input_in = open_in file in
  let r = new_result () in
  (* Tail-recursive: the recursive call is not wrapped in a handler. *)
  let rec test_ w =
    match (try Some (input_line input_in) with End_of_file -> None) with
    | None -> r
    | Some line ->
      let sign, x = parse_line line in
      let margin = Vector.prod w x in
      (* Same decision rule as [judge], reusing the margin computed above. *)
      let y_bar = if margin > 0.0 then 1 else -1 in
      let hinge =
        let v = 1.0 -. (float_of_int sign *. margin) in
        if v < 0.0 then 0.0 else v
      in
      let tau = f_tau x hinge in
      update_result r y_bar sign;
      test_ (Vector.add w (Vector.s_prod (tau *. float_of_int sign) x))
  in
  let result =
    try test_ ww
    with e -> close_in_noerr input_in; raise e
  in
  close_in input_in;
  result;;