Oliver Gu avatar Oliver Gu committed 7c7726f

Added n-fold cross validation mode to svm cli.

Comments (0)

Files changed (5)

+2013-02-24: Added n-fold cross validation mode to svm command line interface.
+
 2013-02-22: Svm node arrays for problems (training sets) without a precomputed
             kernel are sparsely represented and do not store attributes which
             have value zero.
 OASISFormat:       0.3
 Name:              libsvm-ocaml
-Version:           0.8.2
+Version:           0.8.3
 Synopsis:          libsvm-ocaml - OCaml bindings to the LIBSVM library
 Description:       libsvm-ocaml offers an OCaml-interface to the LIBSVM library
 Authors:           Oliver Gu <odietric@gmail.com>

examples/svm_cli.ml

         ~doc:" turn this on when shrinking heuristics should not be used"
       +> flag "-p" no_arg
         ~doc:" turn this on to train a svc or svr model with probability estimates"
+      +> flag "-v N" (optional int)
+        ~doc:" N-fold cross validation mode"
       +> flag "-q" no_arg
         ~doc:" quiet mode (no ouputs)"
       +> anon ("TRAINING-SET-FILE" %: file)
-      +> anon ("MODEL-FILE" %: file)
+      +> anon (maybe ("MODEL-FILE" %: file))
     )
     (fun svm_type kernel degree gamma coef0 c nu eps cachesize tol
-      turn_shrinking_off probability quiet training_set_file model_file () ->
+      turn_shrinking_off probability n_folds quiet training_set_file model_file () ->
         match Result.try_with (fun () -> Svm.Problem.load training_set_file) with
         | Error exn ->
           prerr_endline (Exn.to_string exn);
           exit 1
         | Ok problem ->
-          let model = Svm.train
-            ?svm_type
-            ?kernel
-            ?degree
-            ?gamma
-            ?coef0
-            ?c
-            ?nu
-            ?eps
-            ?cachesize
-            ?tol
-            ~shrinking:(if turn_shrinking_off then `off else `on)
-            ~probability
-            ~verbose:(not quiet)
-            problem
-          in
-          Svm.Model.save model model_file)
+          match n_folds with
+          | None ->
+            let model = Svm.train
+              ?svm_type
+              ?kernel
+              ?degree
+              ?gamma
+              ?coef0
+              ?c
+              ?nu
+              ?eps
+              ?cachesize
+              ?tol
+              ~shrinking:(if turn_shrinking_off then `off else `on)
+              ~probability
+              ~verbose:(not quiet)
+              problem
+            in
+            let model_file = Option.value model_file
+              ~default:(sprintf "%s.model" training_set_file)
+            in
+            Svm.Model.save model model_file
+          | Some n_folds ->
+            let predicted = Svm.cross_validation
+              ?svm_type
+              ?kernel
+              ?degree
+              ?gamma
+              ?coef0
+              ?c
+              ?nu
+              ?eps
+              ?cachesize
+              ?tol
+              ~shrinking:(if turn_shrinking_off then `off else `on)
+              ~probability
+              ~verbose:(not quiet)
+              ~n_folds
+              problem
+            in
+            let expected = Svm.Problem.get_targets problem in
+            match Option.value svm_type ~default:`C_SVC with
+            | `C_SVC | `NU_SVC | `ONE_CLASS ->
+              let accuracy = Stats.calc_accuracy ~expected ~predicted in
+              printf "Cross Validation Accuracy = %g%%\n" (100. *. accuracy)
+            | `EPSILON_SVR | `NU_SVR ->
+              let mse = Stats.calc_mse ~expected ~predicted in
+              let scc = Stats.calc_scc ~expected ~predicted in
+              printf "Cross Validation Mean squared error = %g\n" mse;
+              printf "Cross Validation Squared correlation coefficient = %g\n" scc)
 
 let predict_cmd =
   Command.basic ~summary:"svm prediction"
         | `EPSILON_SVR | `NU_SVR ->
           let mse = Stats.calc_mse ~expected ~predicted in
           let scc = Stats.calc_scc ~expected ~predicted in
-          printf "Mean sqared error = %g (regression)\n" mse;
+          printf "Mean squared error = %g (regression)\n" mse;
           printf "Squared correlation coefficient = %g (regression)\n" scc)
 
 let () =
   Exn.handle_uncaught ~exit:true (fun () ->
-    Command.run ~version:"0.8" ~build_info:"N/A"
+    Command.run ~version:"0.8.3" ~build_info:"N/A"
       (Command.group ~summary:"Command line tools for Libsvm"
          [ "scale"  , scale_cmd
          ; "train"  , train_cmd
 # OASIS_START
-# DO NOT EDIT (digest: 28061f77c5ec5c3f4a2e296ef62d5f1f)
-version = "0.8.2"
+# DO NOT EDIT (digest: f7306597525e6b4b2e4656b10873e706)
+version = "0.8.3"
 description = "libsvm-ocaml - OCaml bindings to the LIBSVM library"
 requires = "core lacaml"
 archive(byte) = "svm.cma"
 (* setup.ml generated for the first time by OASIS v0.2.0 *)
 
 (* OASIS_START *)
-(* DO NOT EDIT (digest: fc0701cc2c8fa07c3173bbbee1259fcc) *)
+(* DO NOT EDIT (digest: 7c30cd7ad01d50e144f5d9dc1e019b4f) *)
 (*
    Regenerated by OASIS v0.3.0
    Visit http://oasis.forge.ocamlcore.org for more information and
           ocaml_version = Some (OASISVersion.VGreaterEqual "3.12");
           findlib_version = Some (OASISVersion.VGreaterEqual "1.3.1");
           name = "libsvm-ocaml";
-          version = "0.8.2";
+          version = "0.8.3";
           license =
             OASISLicense.DEP5License
               (OASISLicense.DEP5Unit
           };
      oasis_fn = Some "_oasis";
      oasis_version = "0.3.0";
-     oasis_digest = Some ">\128\233\017H\170\192A\131.T\246\140~f\191";
+     oasis_digest = Some "\030t\203\178\015\168UF\128&\170\215\179\0244\230";
      oasis_exec = None;
      oasis_setup_args = [];
      setup_update = false;
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.