libsvm-ocaml / lib / libsvm.ml

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
(*
   LIBSVM-OCaml - OCaml-bindings to the LIBSVM library

   Copyright (C) 2013-  Oliver Gu
   email: odietric@gmail.com

   Copyright (C) 2005  Dominik Brugger
   email: dominikbrugger@fastmail.fm
   WWW:   http://ocaml-libsvm.berlios.de

   This library is free software; you can redistribute it and/or
   modify it under the terms of the GNU Lesser General Public
   License as published by the Free Software Foundation; either
   version 2 of the License, or (at your option) any later version.

   This library is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
   Lesser General Public License for more details.

   You should have received a copy of the GNU Lesser General Public
   License along with this library; if not, write to the Free Software
   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
*)

open Core.Std
open Lacaml.D
open Printf

module Svm = struct
  (* Abstract types.  This file never looks inside them: values are only
     produced and consumed through the externals declared in [Stub]. *)
  type problem
  type params
  type model

  (* Kind of SVM.  Values of this type cross the stub boundary (returned by
     [Stub.svm_get_svm_type], stored in [svm_params]).
     NOTE(review): the constructor order presumably has to stay in sync with
     the C side -- confirm before reordering or adding constructors. *)
  type svm_type =
  | C_SVC
  | NU_SVC
  | ONE_CLASS
  | EPSILON_SVR
  | NU_SVR

  (* Kind of kernel.  Values of this type cross the stub boundary (returned by
     [Stub.svm_get_kernel_type], stored in [svm_params]).
     NOTE(review): same ordering caveat as for [svm_type] -- confirm. *)
  type kernel_type =
  | LINEAR
  | POLY
  | RBF
  | SIGMOID
  | PRECOMPUTED

  (* Plain OCaml record describing the training parameters.  It is built by
     [create_params] and handed as a whole to [Stub.svm_param_create], which
     turns it into an abstract [params] value.
     NOTE(review): the field order presumably mirrors what the C stub reads --
     confirm before reordering. *)
  type svm_params = {
    svm_type : svm_type;
    kernel_type : kernel_type;
    degree : int;
    gamma : float;
    coef0 : float;
    c : float;
    nu : float;
    eps : float;
    cachesize : float;
    tol : float;
    shrinking : bool;
    probability : bool;
    (* [create_params] sets this to the length of [weight]. *)
    nr_weight : int;
    (* [weight_label] and [weight] are the two halves of the unzipped
       [(label, weight)] list passed to [create_params]. *)
    weight_label : int list;
    weight : float list;
  }

  (* Raw bindings to the C stubs.  Everything in here is an [external]
     declaration; the rest of the file wraps these in a safer interface.
     All positions passed to the [_set]/[_get] externals by the callers in
     this file are 0-based. *)
  module Stub = struct
    (* Abstract containers created and filled through the externals below. *)
    type double_array
    type svm_node
    type svm_node_array
    type svm_node_matrix

    (* Array of floats; used in this file for the target values. *)
    external double_array_create :
      int -> double_array = "double_array_create_stub"
    external double_array_set :
      double_array -> int -> float -> unit = "double_array_set_stub"
    external double_array_get :
      double_array -> int -> float = "double_array_get_stub"

    (* Array of (index, value) nodes describing one sample.
       [svm_node_array_set nodes pos index value] stores a node at [pos].
       Callers in this file always store a final node with index [-1] in the
       last slot. *)
    external svm_node_array_create :
      int -> svm_node_array = "svm_node_array_create_stub"
    external svm_node_array_set :
      svm_node_array -> int -> int -> float -> unit = "svm_node_array_set_stub"

    (* Array of node arrays, one per sample. *)
    external svm_node_matrix_create :
      int -> svm_node_matrix = "svm_node_matrix_create_stub"
    external svm_node_matrix_set :
      svm_node_matrix -> int -> svm_node_array -> unit = "svm_node_matrix_set_stub"

    (* Problem construction and inspection.  Callers in this file set [l] to
       the number of samples, [x] to the node matrix and [y] to the targets. *)
    external svm_problem_create :
      unit -> problem = "svm_problem_create_stub"
    external svm_problem_l_set :
      problem -> int -> unit = "svm_problem_l_set_stub"
    external svm_problem_l_get :
      problem -> int = "svm_problem_l_get_stub"
    external svm_problem_y_set :
      problem -> double_array -> unit = "svm_problem_y_set_stub"
    external svm_problem_y_get :
      problem -> int -> float = "svm_problem_y_get_stub"
    external svm_problem_x_set :
      problem -> svm_node_matrix -> unit = "svm_problem_x_set_stub"
    (* [svm_problem_x_get prob i j] returns the [(index, value)] pair of node
       [j] of sample [i]. *)
    external svm_problem_x_get :
      problem -> int -> int -> (int * float) = "svm_problem_x_get_stub"
    (* [svm_problem_width prob i]: callers use the result as the number of
       nodes of sample [i] (they iterate [j] from [0] to [width-1]).
       NOTE(review): presumably excludes the terminating [-1] node --
       confirm against the C stub. *)
    external svm_problem_width :
      problem -> int -> int = "svm_problem_width_stub"
    external svm_problem_print :
      problem -> unit = "svm_problem_print_stub"
    (* Converts the OCaml parameter record into an abstract [params]. *)
    external svm_param_create : svm_params -> params = "svm_param_create_stub"

    (* Training.  Note that there is no counterpart to [svm_set_quiet_mode]
       for turning output back on. *)
    external svm_set_quiet_mode : unit -> unit = "svm_set_quiet_mode_stub"
    external svm_train : problem -> params -> model = "svm_train_stub"
    external svm_cross_validation :
      problem -> params -> int -> vec = "svm_cross_validation_stub"

    (* Model persistence; the string is the file name. *)
    external svm_save_model : string -> model -> unit = "svm_save_model_stub"
    external svm_load_model : string -> model = "svm_load_model_stub"

    (* Model inspection. *)
    external svm_get_svm_type : model -> svm_type = "svm_get_svm_type_stub"
    external svm_get_kernel_type : model -> kernel_type = "svm_get_kernel_type_stub"
    external svm_get_nr_class : model -> int = "svm_get_nr_class_stub"
    external svm_get_labels : model -> int list = "svm_get_labels_stub"
    external svm_get_svr_probability :
      model -> float = "svm_get_svr_probability_stub"
    external svm_check_probability_model :
      model -> bool = "svm_check_probability_model_stub"

    (* Prediction for a single sample given as a node array. *)
    external svm_predict_values :
      model -> svm_node_array -> float array = "svm_predict_values_stub"
    external svm_predict :
      model -> svm_node_array -> float = "svm_predict_stub"
    external svm_predict_probability :
      model -> svm_node_array -> float * float array = "svm_predict_probability_stub"
  end

  (* Build a sparse svm node array from the vector [v].  Entries equal to
     zero are skipped; every other entry is stored, in order, together with
     the index reported by [Vec.iteri].  A terminating node with index [-1]
     is written into the last slot. *)
  let sparse_svm_node_array_of_vec v =
    let is_nonzero x = x <> 0. in
    let n_nonzero =
      Vec.fold (fun acc x -> if is_nonzero x then acc + 1 else acc) 0 v
    in
    (* One extra slot for the terminating node. *)
    let nodes = Stub.svm_node_array_create (n_nonzero + 1) in
    let next = ref 0 in
    let store index value =
      if is_nonzero value then begin
        Stub.svm_node_array_set nodes !next index value;
        incr next
      end
    in
    Vec.iteri store v;
    Stub.svm_node_array_set nodes !next (-1) 0.;
    nodes

  (* Build a dense svm node array from the vector [v]: every entry is stored,
     zero or not.  For the index [i] reported by [Vec.iteri], the entry goes
     into slot [i-1] and is tagged with feature index [i-1] as well.  A
     terminating node with index [-1] is written into the last slot. *)
  let svm_node_array_of_vec v =
    let dim = Vec.dim v in
    let nodes = Stub.svm_node_array_create (dim + 1) in
    let store i value =
      let slot = i - 1 in
      Stub.svm_node_array_set nodes slot slot value
    in
    Vec.iteri store v;
    Stub.svm_node_array_set nodes dim (-1) 0.;
    nodes

  (* Build an svm node array from the list [l] of [(index, value)] pairs.
     [len] must be the length of [l]; the array gets [len + 1] slots, the
     last of which holds the terminating node with index [-1]. *)
  let svm_node_array_of_list l ~len =
    let nodes = Stub.svm_node_array_create (len + 1) in
    let store pos (index, value) =
      Stub.svm_node_array_set nodes pos index value
    in
    List.iteri l ~f:store;
    Stub.svm_node_array_set nodes len (-1) 0.;
    nodes

  (* Return the number of lines in [file]. *)
  let count_lines file =
    let count ic =
      In_channel.fold_lines ic ~init:0 ~f:(fun n _line -> n + 1)
    in
    In_channel.with_file file ~f:count

  (* [parse_line file] returns a staged parser for one line of a file in the
     sparse text format [target index:value index:value ...].  [file] is only
     used in error messages.

     The unstaged function takes the raw [line] and its position [~pos]
     (callers in this file count lines from 1; the value is only used for
     error reporting) and returns the pair [(target, feats)], where [feats]
     is the list of [(index, value)] pairs in the order they appear.

     Fields are separated by spaces.  Trailing whitespace is stripped, and
     empty fields produced by leading or repeated spaces are ignored rather
     than rejected.

     Raises [Failure], naming the file, the line number and the underlying
     exception, when the line is blank or any field fails to parse. *)
  let parse_line file = stage (fun line ~pos ->
    let parse () =
      let fields =
        List.filter (String.split (String.rstrip line) ~on:' ')
          ~f:(fun field -> not (String.is_empty field))
      in
      match fields with
      | [] -> failwith "empty line"
      | x :: xs ->
        let target = Float.of_string x in
        let feats = List.map xs ~f:(fun str ->
          let index, value = String.lsplit2_exn str ~on:':' in
          Int.of_string index, Float.of_string value)
        in
        target, feats
    in
    match Result.try_with parse with
    | Ok x -> x
    | Error exn ->
      failwithf "%s: wrong input format at line %d: %s" file pos (Exn.to_string exn) ())

  module Problem = struct
    (* A training problem: the stub-side [problem] handle together with the
       number of samples and the number of features recorded when it was
       built. *)
    type t = {
      n_samples : int;
      n_feats : int;
      prob : problem;
    }

    (* Number of samples recorded in the problem. *)
    let get_n_samples { n_samples; _ } = n_samples

    (* Number of features recorded in the problem. *)
    let get_n_feats { n_feats; _ } = n_feats

    (* Generic problem constructor.  [x] is a matrix with one sample per row,
       [y] holds one target per sample (read at positions [1 .. Mat.dim1 x])
       and [f] converts a single sample vector into an svm node array.
       The number of features is taken to be the number of columns of [x]. *)
    let create_gen x y ~f =
      let n_samples = Mat.dim1 x in
      let n_feats = Mat.dim2 x in
      (* Rows of [x] are obtained as columns of its transpose. *)
      let samples = Mat.transpose x in
      let node_matrix = Stub.svm_node_matrix_create n_samples in
      let targets = Stub.double_array_create n_samples in
      let store_sample i =
        (* The matrix and [y] are indexed from 1, the stub arrays from 0. *)
        let slot = i - 1 in
        Stub.svm_node_matrix_set node_matrix slot (f (Mat.col samples i));
        Stub.double_array_set targets slot y.{i}
      in
      for i = 1 to n_samples do
        store_sample i
      done;
      let prob = Stub.svm_problem_create () in
      Stub.svm_problem_l_set prob n_samples;
      Stub.svm_problem_x_set prob node_matrix;
      Stub.svm_problem_y_set prob targets;
      { n_samples; n_feats; prob }

    (* Build a problem from the sample matrix [x] (one sample per row) and
       the target vector [y].  Samples are stored sparsely: zero entries are
       dropped. *)
    let create ~x ~y = create_gen ~f:sparse_svm_node_array_of_vec x y

    (* Build a problem from the matrix [k] (one sample per row) and the
       target vector [y].  Samples are stored densely: every entry is kept.
       NOTE(review): presumably [k] is a precomputed kernel matrix -- confirm
       against the interface documentation. *)
    let create_k ~k ~y = create_gen ~f:svm_node_array_of_vec k y

    (* Read a problem from [file], which must contain one sample per line in
       the sparse text format understood by [parse_line].

       The file is read twice: once to count the lines (which fixes the
       number of samples) and once to parse them.

       The number of features is the largest feature index found anywhere in
       the file, not the largest number of entries on a single line.  For
       sparse data the two differ, and [min_max_feats] sizes its vectors by
       this number and addresses them by feature index.

       Raises [Failure] (from [parse_line]) on a malformed line. *)
    let load file =
      let n_samples = count_lines file in
      let n_feats = ref 0 in
      let x = Stub.svm_node_matrix_create n_samples in
      let y = Stub.double_array_create n_samples in
      In_channel.with_file file ~f:(fun ic ->
        let parse_line = unstage (parse_line file) in
        (* [i] is the 1-based line number; stub arrays are 0-based. *)
        let rec loop i =
          match In_channel.input_line ic with
          | None -> ()
          | Some line ->
            let target, feats = parse_line line ~pos:i in
            Stub.double_array_set y (i-1) target;
            let len = List.length feats in
            Stub.svm_node_matrix_set x (i-1) (svm_node_array_of_list feats ~len);
            n_feats := List.fold feats ~init:!n_feats
              ~f:(fun largest (index, _value) -> Int.max largest index);
            loop (i+1)
        in
        loop 1);
      let prob = Stub.svm_problem_create () in
      Stub.svm_problem_l_set prob n_samples;
      Stub.svm_problem_x_set prob x;
      Stub.svm_problem_y_set prob y;
      { n_samples;
        n_feats = !n_feats;
        prob;
      }

    (* Return the target values of all samples as a fresh vector; element [i]
       of the result is the target of the stub-side sample [i-1]. *)
    let get_targets t =
      Vec.init t.n_samples (fun i -> Stub.svm_problem_y_get t.prob (i - 1))

    (* Write the problem to the channel [oc], one sample per line: the target
       printed with [%g], followed by one [ index:value] field per stored
       node.  The channel is flushed at the end but not closed. *)
    let output t oc =
      (* Scratch buffer holding the line being assembled; reused per sample. *)
      let buf = Buffer.create 1024 in
      let emit_sample i =
        Buffer.add_string buf (sprintf "%g" (Stub.svm_problem_y_get t.prob i));
        for j = 0 to Stub.svm_problem_width t.prob i - 1 do
          let index, value = Stub.svm_problem_x_get t.prob i j in
          Buffer.add_string buf (sprintf " %d:%g" index value)
        done;
        Buffer.add_char buf '\n';
        Out_channel.output_string oc (Buffer.contents buf);
        Buffer.clear buf
      in
      for i = 0 to t.n_samples - 1 do
        emit_sample i
      done;
      Out_channel.flush oc

    (* Write the problem to [file] in the format produced by [output]. *)
    let save t file = Out_channel.with_file file ~f:(output t)

    (* Compute, for every feature index that occurs in the problem, the
       smallest and largest stored value.  Both result vectors have
       [t.n_feats] elements and are addressed by feature index; positions of
       indices that never occur keep their initial values ([infinity] for the
       minimum, [neg_infinity] for the maximum).  Only stored nodes are
       visited, so entries dropped by a sparse representation do not take
       part. *)
    let min_max_feats t =
      let lo = Vec.make t.n_feats Float.infinity in
      let hi = Vec.make t.n_feats Float.neg_infinity in
      let update i j =
        let index, value = Stub.svm_problem_x_get t.prob i j in
        lo.{index} <- Float.min lo.{index} value;
        hi.{index} <- Float.max hi.{index} value
      in
      for i = 0 to t.n_samples - 1 do
        for j = 0 to Stub.svm_problem_width t.prob i - 1 do
          update i j
        done
      done;
      (`Min lo, `Max hi)

    (* Return a new problem whose stored feature values are linearly rescaled
       so that, per feature index, [min_feats.{index}] maps to [lower] and
       [max_feats.{index}] maps to [upper] (defaults [-1.] and [1.]).
       Targets are copied unchanged and the original problem is not modified.
       [min_feats] and [max_feats] are addressed by feature index, as
       produced by [min_max_feats]. *)
    let scale ?(lower= -.1.) ?(upper=1.) t ~min_feats ~max_feats =
      let n_samples = t.n_samples in
      let x = Stub.svm_node_matrix_create n_samples in
      let y = Stub.double_array_create n_samples in
      (* Map one value of feature [index] into [lower, upper].  The two
         end points are special-cased so that they map exactly to the bounds;
         this also covers features whose minimum and maximum coincide as long
         as the value equals them. *)
      let rescale index value =
        let lo = min_feats.{index} in
        if Float.(=.) value lo then lower
        else
          let hi = max_feats.{index} in
          if Float.(=.) value hi then upper
          else lower +. (upper -. lower) *. (value -. lo) /. (hi -. lo)
      in
      for i = 0 to n_samples - 1 do
        let width = Stub.svm_problem_width t.prob i in
        (* One extra slot for the terminating node. *)
        let nodes = Stub.svm_node_array_create (width + 1) in
        for j = 0 to width - 1 do
          let index, value = Stub.svm_problem_x_get t.prob i j in
          Stub.svm_node_array_set nodes j index (rescale index value)
        done;
        Stub.svm_node_array_set nodes width (-1) 0.;
        Stub.svm_node_matrix_set x i nodes;
        Stub.double_array_set y i (Stub.svm_problem_y_get t.prob i)
      done;
      let scaled_prob = Stub.svm_problem_create () in
      Stub.svm_problem_l_set scaled_prob n_samples;
      Stub.svm_problem_x_set scaled_prob x;
      Stub.svm_problem_y_set scaled_prob y;
      { n_samples; n_feats = t.n_feats; prob = scaled_prob }

    (* Print the problem using the stub-side printer. *)
    let print { prob; _ } = Stub.svm_problem_print prob
  end

  (* Operations on trained models. *)
  module Model = struct
    type t = model

    (* Return the SVM type of the model as a polymorphic variant. *)
    let get_svm_type t =
      let to_tag = function
        | C_SVC       -> `C_SVC
        | NU_SVC      -> `NU_SVC
        | ONE_CLASS   -> `ONE_CLASS
        | EPSILON_SVR -> `EPSILON_SVR
        | NU_SVR      -> `NU_SVR
      in
      to_tag (Stub.svm_get_svm_type t)

    (* Number of classes as reported by the stub. *)
    let get_n_classes = Stub.svm_get_nr_class

    (* Return the class labels of a classification model.
       Raises [Invalid_argument] for regression and one-class models. *)
    let get_labels t =
      match Stub.svm_get_svm_type t with
      | C_SVC | NU_SVC -> Stub.svm_get_labels t
      | NU_SVR | EPSILON_SVR | ONE_CLASS ->
        invalid_arg "Cannot return labels for a regression or one-class model."

    (* Return the value reported by [Stub.svm_get_svr_probability] for a
       regression model.  Raises [Invalid_argument] for any other model. *)
    let get_svr_probability t =
      match Stub.svm_get_svm_type t with
      | C_SVC | NU_SVC | ONE_CLASS ->
        invalid_arg "The model is no regression model."
      | EPSILON_SVR | NU_SVR -> Stub.svm_get_svr_probability t

    (* Write the model to the file [filename]. *)
    let save t filename = Stub.svm_save_model filename t

    (* Read a model from the file [filename]. *)
    let load = Stub.svm_load_model
  end

  (* Translate the user-facing training options into the record expected by
     [Stub.svm_param_create] and return the resulting stub-side [params].
     The polymorphic variants for SVM type, kernel and shrinking are mapped
     to their internal counterparts, and [weights], a list of
     [(label, weight)] pairs, is split into two parallel lists. *)
  let create_params ~svm_type ~kernel ~degree ~gamma ~coef0 ~c
      ~nu ~eps ~cachesize ~tol ~shrinking ~probability ~weights =
    let svm_type_of_tag = function
      | `C_SVC       -> C_SVC
      | `NU_SVC      -> NU_SVC
      | `ONE_CLASS   -> ONE_CLASS
      | `EPSILON_SVR -> EPSILON_SVR
      | `NU_SVR      -> NU_SVR
    in
    let kernel_type_of_tag = function
      | `LINEAR      -> LINEAR
      | `POLY        -> POLY
      | `RBF         -> RBF
      | `SIGMOID     -> SIGMOID
      | `PRECOMPUTED -> PRECOMPUTED
    in
    let bool_of_switch = function
      | `on  -> true
      | `off -> false
    in
    let weight_label, weight = List.unzip weights in
    Stub.svm_param_create {
      svm_type = svm_type_of_tag svm_type;
      kernel_type = kernel_type_of_tag kernel;
      degree;
      gamma;
      coef0;
      c;
      nu;
      eps;
      cachesize;
      tol;
      shrinking = bool_of_switch shrinking;
      probability;
      nr_weight = List.length weight;
      weight_label;
      weight;
    }

  (* Train a model on [problem].

     All optional arguments are forwarded to [create_params]; the defaults
     are the ones visible in the signature below.

     [gamma]: a value different from [0.] is used as given.  The default
     [0.] means that the caller made no choice, in which case
     [1 / number of features of the problem] is used.

     [verbose]: when [false] (the default) the library is switched to quiet
     mode before training.  NOTE(review): this file has no way to switch
     quiet mode off again, so a later call with [~verbose:true] presumably
     stays quiet -- confirm against the C stub. *)
  let train
      ?(svm_type=`C_SVC)
      ?(kernel=`RBF)
      ?(degree=3)
      ?(gamma=0.)
      ?(coef0=0.)
      ?(c=1.)
      ?(nu=0.5)
      ?(eps=0.1)
      ?(cachesize=100.)
      ?(tol=1e-3)
      ?(shrinking=`on)
      ?(probability=false)
      ?(weights=[])
      ?(verbose=false)
      problem =
    (* Honour a caller-supplied gamma; fall back to 1/n_feats otherwise. *)
    let gamma =
      if gamma = 0. then 1. /. float problem.Problem.n_feats else gamma
    in
    let params = create_params
      ~svm_type ~kernel ~degree ~gamma ~coef0 ~c ~nu ~eps
      ~cachesize ~tol ~shrinking ~probability ~weights
    in
    if not verbose then Stub.svm_set_quiet_mode () else ();
    Stub.svm_train problem.Problem.prob params

  (* Run an [n_folds]-fold cross validation on [problem] and return the
     vector produced by [Stub.svm_cross_validation].
     NOTE(review): presumably one predicted target per sample -- confirm
     against the C stub.

     All optional arguments are forwarded to [create_params]; the defaults
     are the ones visible in the signature below.

     [gamma]: a value different from [0.] is used as given.  The default
     [0.] means that the caller made no choice, in which case
     [1 / number of features of the problem] is used.

     [verbose]: when [false] (the default) the library is switched to quiet
     mode first.  NOTE(review): this file has no way to switch quiet mode
     off again, so a later call with [~verbose:true] presumably stays
     quiet -- confirm against the C stub. *)
  let cross_validation
      ?(svm_type=`C_SVC)
      ?(kernel=`RBF)
      ?(degree=3)
      ?(gamma=0.)
      ?(coef0=0.)
      ?(c=1.)
      ?(nu=0.5)
      ?(eps=0.1)
      ?(cachesize=100.)
      ?(tol=1e-3)
      ?(shrinking=`on)
      ?(probability=false)
      ?(weights=[])
      ?(verbose=false)
      ~n_folds problem =
    (* Honour a caller-supplied gamma; fall back to 1/n_feats otherwise. *)
    let gamma =
      if gamma = 0. then 1. /. float problem.Problem.n_feats else gamma
    in
    let params = create_params
      ~svm_type ~kernel ~degree ~gamma ~coef0 ~c ~nu ~eps
      ~cachesize ~tol ~shrinking ~probability ~weights
    in
    if not verbose then Stub.svm_set_quiet_mode () else ();
    Stub.svm_cross_validation problem.Problem.prob params n_folds

  (* Predict the target of the single sample [x].  The sample is converted
     densely for models with a precomputed kernel and sparsely for every
     other kernel. *)
  let predict_one model ~x =
    let to_nodes = match Stub.svm_get_kernel_type model with
      | PRECOMPUTED -> svm_node_array_of_vec
      | LINEAR | POLY | RBF | SIGMOID -> sparse_svm_node_array_of_vec
    in
    Stub.svm_predict model (to_nodes x)

  (* Predict the targets of all samples in [x] (one sample per row) and
     return them as a vector with one element per row. *)
  let predict model ~x =
    (* Rows of [x] are obtained as columns of its transpose. *)
    let samples = Mat.transpose x in
    Vec.init (Mat.dim1 x) (fun i -> predict_one model ~x:(Mat.col samples i))

  (* Return the decision values for the single sample [x] as a matrix.

     For regression and one-class models the result is a 1x1 matrix holding
     the first value returned by the stub.

     For classification models the result is an [n x n] matrix, [n] being
     the number of classes.  The values returned by the stub are consumed in
     order for the pairs [(i, j)] with [i < j]; entry [(i, j)] receives the
     value, entry [(j, i)] its negation, and the diagonal stays [0.]. *)
  let predict_values model ~x =
    let to_nodes = match Stub.svm_get_kernel_type model with
      | PRECOMPUTED -> svm_node_array_of_vec
      | LINEAR | POLY | RBF | SIGMOID -> sparse_svm_node_array_of_vec
    in
    let dec_vals = Stub.svm_predict_values model (to_nodes x) in
    match Stub.svm_get_svm_type model with
    | EPSILON_SVR | NU_SVR | ONE_CLASS ->
      [| [| dec_vals.(0) |] |]
    | C_SVC | NU_SVC ->
      let n_classes = Stub.svm_get_nr_class model in
      let dec_mat = Array.make_matrix n_classes n_classes 0. in
      (* Position of the next unread decision value. *)
      let next = ref 0 in
      for i = 0 to n_classes - 1 do
        for j = i + 1 to n_classes - 1 do
          let d = dec_vals.(!next) in
          dec_mat.(i).(j) <- d;
          dec_mat.(j).(i) <- -. d;
          incr next
        done
      done;
      dec_mat

  (* Predict the sample [x] with probability estimates and return the pair
     produced by [Stub.svm_predict_probability].
     NOTE(review): presumably (predicted label, per-class probabilities) --
     confirm against the C stub.

     Raises [Invalid_argument] for regression and one-class models, and for
     classification models for which the stub reports no probability
     information. *)
  let predict_probability model ~x =
    match Stub.svm_get_svm_type model with
    | EPSILON_SVR | NU_SVR ->
      invalid_arg "For probability estimates call Model.get_svr_probability."
    | ONE_CLASS ->
      invalid_arg "One-class problems do not support probability estimates."
    | C_SVC | NU_SVC ->
      if not (Stub.svm_check_probability_model model) then
        invalid_arg "Model does not support probability estimates."
      else
        let to_nodes = match Stub.svm_get_kernel_type model with
          | PRECOMPUTED -> svm_node_array_of_vec
          | LINEAR | POLY | RBF | SIGMOID -> sparse_svm_node_array_of_vec
        in
        Stub.svm_predict_probability model (to_nodes x)

  (* Predict every sample stored in [file] (sparse text format, one sample
     per line) and return both the targets found in the file and the
     predictions, as [(`Expected e, `Predicted p)].  Element [i] of either
     vector belongs to line [i] of the file.

     The file is read twice: once to count its lines and once to parse them.
     Raises [Failure] (from [parse_line]) on a malformed line. *)
  let predict_from_file model file =
    let n_samples = count_lines file in
    let expected = Vec.create n_samples in
    let predicted = Vec.create n_samples in
    let parse_line = unstage (parse_line file) in
    (* Handle line number [i] and return the number of the next line. *)
    let handle_line i line =
      let target, feats = parse_line line ~pos:i in
      expected.{i} <- target;
      let nodes = svm_node_array_of_list feats ~len:(List.length feats) in
      predicted.{i} <- Stub.svm_predict model nodes;
      i + 1
    in
    In_channel.with_file file ~f:(fun ic ->
      ignore (In_channel.fold_lines ic ~init:1 ~f:handle_line : int);
      (`Expected expected, `Predicted predicted))
end

(* Simple quality measures comparing a vector of expected targets with a
   vector of predicted targets.  Every function raises [Invalid_argument]
   when the two vectors differ in length. *)
module Stats = struct

  (* Raise [Invalid_argument] unless [x] and [y] have the same dimension.
     [location] is the name of the calling function and only appears in the
     error message. *)
  let check_dimension x y ~location =
    let dimx = Vec.dim x in
    let dimy = Vec.dim y in
    if dimx = dimy then ()
    else
      invalid_argf "dimension mismatch in Stats.%s: %d <> %d" location dimx dimy ()

  (* Number of positions at which [expected] and [predicted] agree exactly
     (their difference is [0.]). *)
  let calc_n_correct ~expected ~predicted =
    check_dimension expected predicted ~location:"calc_n_correct";
    let diff = Vec.sub expected predicted in
    Vec.fold (fun hits d -> if d = 0. then hits + 1 else hits) 0 diff

  (* Fraction of positions at which [expected] and [predicted] agree
     exactly.  The result is [nan] for empty vectors. *)
  let calc_accuracy ~expected ~predicted =
    check_dimension expected predicted ~location:"calc_accuracy";
    let n_correct = calc_n_correct ~expected ~predicted in
    float n_correct /. float (Vec.dim expected)

  (* Mean squared error between [predicted] and [expected]. *)
  let calc_mse ~expected ~predicted =
    check_dimension expected predicted ~location:"calc_mse";
    let n = Vec.dim expected in
    Vec.ssqr_diff predicted expected /. float n

  (* Squared correlation coefficient between [predicted] (x) and [expected]
     (y, the true values):

       (l*Sxy - Sx*Sy)^2 / ((l*Sxx - Sx^2) * (l*Syy - Sy^2))

     where [l] is the number of samples and [S..] are sums over all samples.
     The result is [nan] when either vector has zero variance or the vectors
     are empty. *)
  let calc_scc ~expected ~predicted =
    check_dimension expected predicted ~location:"calc_scc";
    let n = Vec.dim expected in
    let sum_x  = ref 0. in
    let sum_y  = ref 0. in
    let sum_xx = ref 0. in
    let sum_yy = ref 0. in
    let sum_xy = ref 0. in
    for i = 1 to n do
      let x = predicted.{i} in
      let y = expected.{i} in
      sum_x  := !sum_x  +. x;
      sum_y  := !sum_y  +. y;
      (* The second-order sums must accumulate over all samples as well. *)
      sum_xx := !sum_xx +. x *. x;
      sum_yy := !sum_yy +. y *. y;
      sum_xy := !sum_xy +. x *. y
    done;
    let sqr v = v *. v in
    let l = float n in
    let num   = l *. !sum_xy -. !sum_x *. !sum_y in
    let den_x = l *. !sum_xx -. sqr !sum_x in
    let den_y = l *. !sum_yy -. sqr !sum_y in
    sqr num /. (den_x *. den_y)
end
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.