Markus Mottl avatar Markus Mottl committed 5666448

Added initial version of hyper parameter optimization

Comments (0)

Files changed (3)

 
       include Shared
     end
+
+    (** Hyper parameter optimization: maximizes the log evidence with a
+        gradient-based GSL minimizer. *)
+    module Optim = struct
+      (** Result of an optimization run. *)
+      module Solution = struct
+        type t = {
+          kernel : Spec.Eval.Kernel.t;
+            (** kernel updated with the final hyper parameter values *)
+          sigma2 : float;
+            (** final value of [sigma2] *)
+          coeffs : vec;
+            (** coefficients of the model trained with [kernel] and [sigma2] *)
+          log_evidence : float;
+            (** log evidence reported by the minimizer at termination *)
+        }
+      end
+
+      (** [solve ?kernel ?sigma2 ?inducing ?n_inducing ~inputs ~targets]
+          minimizes the negative log evidence over [log sigma2] and the
+          hyper parameters extracted from the kernel, and returns the
+          resulting {!Solution.t}.
+
+          @param kernel default: [Eval_inputs.create_default_kernel inputs]
+          @param sigma2 default: [Vec.sqr_nrm2 targets].  The starting value
+            is bounded below by [min_float] so that its logarithm is finite.
+          @param inducing inducing points to use.  If absent, [n_inducing]
+            inputs are picked by
+            [Eval_inducing.Prepared.choose_random_n_inputs].
+          @param n_inducing default: a tenth of the number of inputs.  A
+            given value is clamped to lie between [0] and the number of
+            inputs.  Ignored if [inducing] is given. *)
+      let solve ?kernel ?sigma2 ?inducing ?n_inducing ~inputs ~targets =
+        let kernel =
+          match kernel with
+          | None -> Eval_inputs.create_default_kernel inputs
+          | Some kernel -> kernel
+        in
+        (* The optimizer works on [log sigma2], hence the default needs the
+           same strictly positive lower bound as a user-supplied value. *)
+        let sigma2 =
+          let sigma2 =
+            match sigma2 with
+            | None -> Vec.sqr_nrm2 targets
+            | Some sigma2 -> sigma2
+          in
+          max sigma2 min_float
+        in
+        let eval_inducing_prepared =
+          match inducing with
+          | None ->
+              let n_inducing =
+                let n_inputs = Spec.Eval.Inputs.get_n_inputs inputs in
+                match n_inducing with
+                | None -> n_inputs / 10
+                | Some n_inducing -> max (min n_inputs n_inducing) 0
+              in
+              Eval_inducing.Prepared.choose_random_n_inputs
+                kernel ~n_inducing inputs
+          | Some inducing -> Eval_inducing.Prepared.calc inducing
+        in
+        let eval_inputs_prepared =
+          Eval_inputs.Prepared.calc eval_inducing_prepared inputs
+        in
+        let deriv_inducing_prepared =
+          Inducing.Prepared.calc eval_inducing_prepared
+        in
+        let deriv_inputs_prepared =
+          Inputs.Prepared.calc deriv_inducing_prepared eval_inputs_prepared
+        in
+        let hyper_vars, hyper_vals = Spec.Hyper.extract kernel in
+        let n_hypers = Array.length hyper_vars in
+        (* Layout of the optimizer's parameter vector: slot 0 holds
+           [log sigma2], slots 1 .. n_hypers hold the kernel hyper
+           parameters.  Its dimension is therefore [n_hypers + 1]. *)
+        let n_gsl_hypers = n_hypers + 1 in
+        let gsl_hypers = Gsl_vector.create n_gsl_hypers in
+        gsl_hypers.{0} <- log sigma2;
+        for i = 1 to n_hypers do gsl_hypers.{i} <- hyper_vals.{i} done;
+        let module Gd = Gsl_multimin.Deriv in
+        (* Translates an optimizer parameter vector back into an updated
+           kernel and [sigma2]. *)
+        let update_hypers ~gsl_hypers =
+          let sigma2 = exp gsl_hypers.{0} in
+          let hyper_vals = Vec.create n_hypers in
+          for i = 1 to n_hypers do hyper_vals.{i} <- gsl_hypers.{i} done;
+          Hyper.update kernel hyper_vals, sigma2
+        in
+        (* Trains an evaluation-only model for the given parameter vector.
+           Shared by the objective function and the final result. *)
+        let calc_eval_trained ~gsl_hypers =
+          let kernel, sigma2 = update_hypers ~gsl_hypers in
+          let eval_inducing =
+            Eval_inducing.calc kernel eval_inducing_prepared
+          in
+          let eval_inputs =
+            Eval_inputs.calc eval_inducing eval_inputs_prepared
+          in
+          let model = Eval_model.calc eval_inputs ~sigma2 in
+          kernel, sigma2, Eval_trained.calc model ~targets
+        in
+        (* Objective: negative log evidence. *)
+        let multim_f ~x:gsl_hypers =
+          let _kernel, _sigma2, trained = calc_eval_trained ~gsl_hypers in
+          -. Eval_trained.calc_log_evidence trained
+        in
+        (* Fills [gradient] with the gradient of the objective and returns
+           the trained derivative model for reuse by [multim_fdf]. *)
+        let multim_dcommon ~x:gsl_hypers ~g:gradient =
+          let kernel, sigma2 = update_hypers ~gsl_hypers in
+          let deriv_inducing = Inducing.calc kernel deriv_inducing_prepared in
+          let deriv_inputs = Inputs.calc deriv_inducing deriv_inputs_prepared in
+          let dmodel = Cm.calc ~sigma2 deriv_inputs in
+          let trained = Trained.calc dmodel ~targets in
+          let dlog_evidence_dsigma2 =
+            Trained.calc_log_evidence_sigma2 trained
+          in
+          (* Chain rule: the optimizer variable is [log sigma2], so the
+             derivative with respect to [sigma2] is scaled by [sigma2]. *)
+          gradient.{0} <- -. dlog_evidence_dsigma2 *. sigma2;
+          let hyper_t = Trained.prepare_hyper trained in
+          for i = 1 to n_hypers do
+            gradient.{i} <-
+              -. Trained.calc_log_evidence hyper_t hyper_vars.(i - 1)
+          done;
+          trained
+        in
+        let multim_df ~x ~g = ignore (multim_dcommon ~x ~g) in
+        let multim_fdf ~x ~g =
+          let trained = multim_dcommon ~x ~g in
+          let log_evidence =
+            Eval_trained.calc_log_evidence (Trained.calc_eval trained)
+          in
+          -. log_evidence
+        in
+        let multim_fun_fdf =
+          {
+            Gsl_fun.
+            multim_f = multim_f;
+            multim_df = multim_df;
+            multim_fdf = multim_fdf;
+          }
+        in
+        (* The minimizer must be allocated for the full parameter vector,
+           i.e. including the [log sigma2] slot, not just [n_hypers]. *)
+        let mumin =
+          Gd.make Gd.VECTOR_BFGS2 n_gsl_hypers
+            multim_fun_fdf ~x:gsl_hypers ~step:1e-3 ~tol:1e-4
+        in
+        (* Iterates until the relative change of the log evidence between
+           two consecutive steps drops below 0.001. *)
+        let rec loop last_log_evidence =
+          let neg_log_likelihood = Gd.minimum ~x:gsl_hypers mumin in
+          let log_evidence = -. neg_log_likelihood in
+          let diff = abs_float (1. -. (log_evidence /. last_log_evidence)) in
+          if diff < 0.001 then
+            let kernel, sigma2, trained = calc_eval_trained ~gsl_hypers in
+            let coeffs = Eval_trained.get_coeffs trained in
+            {
+              Solution.
+              kernel = kernel;
+              sigma2 = sigma2;
+              coeffs = coeffs;
+              log_evidence = log_evidence;
+            }
+          else begin
+            Gd.iterate mumin;
+            loop log_evidence
+          end
+        in
+        loop neg_infinity
+    end
   end
 end
 
   module SPGP = struct
     module Spec = Spec
 
-    let solve ?kernel ?inducing ~inputs ~targets =
-      ignore (kernel, inducing, inputs, targets);
-      (assert false (* XXX *))
+    module Solution = struct
+    end
+
+    let solve ?kernel ?sigma2 ?inducing ?n_inducing ~inputs ~targets =
+      ignore (kernel, sigma2, inducing, n_inducing, inputs, targets);
+      (assert false (* TODO *))
   end
 end

lib/interfaces.ml

         val calc_log_evidence : hyper_t -> Spec.Hyper.t -> float
       end
 
-(*
       module Optim : sig
+        module Solution : sig
+          type t = {
+            kernel : Eval.Spec.Kernel.t;
+            sigma2 : float;
+            coeffs : vec;
+            log_evidence : float;
+          }
+        end
+
         val solve :
           ?kernel : Eval.Spec.Kernel.t ->
+          ?sigma2 : float ->
           ?inducing : Eval.Spec.Inducing.t ->
+          ?n_inducing : int ->
           inputs : Eval.Spec.Inputs.t ->
           targets : vec ->
-          Eval.Spec.Kernel.t * vec
+          Solution.t
       end
-*)
     end
   end
 
         with module Eval = Deriv.Eval.Spec
         with module Deriv = Deriv.Deriv.Spec
 
+      module Solution : sig
+      end
+
       val solve :
-        ?kernel : Spec.Eval.Kernel.t ->
-        ?inducing : Spec.Eval.Inducing.t ->
-        inputs : Spec.Eval.Inputs.t ->
+        ?kernel : Eval.Spec.Kernel.t ->
+        ?sigma2 : float ->
+        ?inducing : Eval.Spec.Inducing.t ->
+        ?n_inducing : int ->
+        inputs : Eval.Spec.Inputs.t ->
         targets : vec ->
-        Spec.Eval.Kernel.t * Spec.Eval.Inducing.t * vec
+        Eval.Spec.Kernel.t * Eval.Spec.Inducing.t * vec
     end
   end
 end

test/find_sigma2.ml

 let find_sigma2 () =
   let module Eval = FITC.Eval in
   let module Deriv = FITC.Deriv in
-  let eval_prep_inducing = Eval.Inducing.Prepared.calc inducing_inputs in
-  let deriv_prep_inducing = Deriv.Inducing.Prepared.calc eval_prep_inducing in
-  let inducing = Deriv.Inducing.calc kernel deriv_prep_inducing in
-  let eval_prep_inputs =
-    Eval.Inputs.Prepared.calc eval_prep_inducing training_inputs
+  let eval_inducing_prep = Eval.Inducing.Prepared.calc inducing_inputs in
+  let deriv_inducing_prep = Deriv.Inducing.Prepared.calc eval_inducing_prep in
+  let inducing = Deriv.Inducing.calc kernel deriv_inducing_prep in
+  let eval_inputs_prep =
+    Eval.Inputs.Prepared.calc eval_inducing_prep training_inputs
   in
-  let deriv_prep_inputs =
-    Deriv.Inputs.Prepared.calc deriv_prep_inducing eval_prep_inputs
+  let deriv_inputs_prep =
+    Deriv.Inputs.Prepared.calc deriv_inducing_prep eval_inputs_prep
   in
-  let inputs = Deriv.Inputs.calc inducing deriv_prep_inputs in
+  let inputs = Deriv.Inputs.calc inducing deriv_inputs_prep in
   let eval_inputs = Deriv.Inputs.calc_eval inputs in
 
   let model_ref = ref None in
     let log_evidence = -. nll in
     let diff = abs_float (1. -. (log_evidence /. last_log_evidence)) in
     printf "diff: %f\n%!" diff;
-    if diff < 0.001 then nll
+    if diff < 0.001 then -. nll, exp x.{0}
     else (
       printf "log evidence: %f\n%!" log_evidence;
       Gd.iterate mumin;
       loop log_evidence)
   in
-  let nll = loop neg_infinity in
-  -. nll, exp x.{0}
+  loop neg_infinity
 
 let main () =
   let log_evidence, sigma2 = find_sigma2 () in
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.