# ocaml-core / base / core / extended / lib / sampler.ml

 ``` 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94``` ```(* This algorithm was originally described in the paper "A Linear Algorithm For Generating Random Numbers With a Given Distribution" by Michael Vose. original paper: http://web.eecs.utk.edu/~vose/Publications/random.pdf decent exposition: http://www.keithschwarz.com/darts-dice-coins/ *) open Core.Std type 'a cell = Single of 'a | Branch of float * 'a * 'a type 'a t = 'a cell array let sample ?state t = let module R = Random.State in let s = Option.value state ~default:R.default in let i = R.int s (Array.length t) in match t.(i) with | Single a -> a | Branch (p, a, b) -> if R.float s 1.0 < p then a else b let create dist = let dist = (* the following two steps are fused together into a single normalization phase that is both more efficient and more numerically stable than the two passes done in sequence. 1. we support histograms as inputs by scaling all "probabilities" so that they add up to 1. This is done by dividing each "probability" by the sum of all the "probabilities". 2. the remainder of the algorithm works in terms of scaled probabilities: multiplied by the size of the distribution. *) let w = List.fold ~f:(fun sum (_, p) -> sum +. p) ~init:0. dist in let n = float (List.length dist) in List.map dist ~f:(fun (a, p) -> (a, (p *. n) /. w)) in (* loop invariants: 1. forall [ p >= 1 | (_, p) <- above ] 2. forall [ p < 1 | (_, p) <- below ] 3. length acc + length above + length below = length dist *) let add (below, above) (x, p) = if p < 1. then ((x, p) :: below, above) else (below, (x, p) :: above) in let rec loop acc = function | ((b, pb) :: below, (a, pa) :: above) -> let pa = (pa +. pb) -. 1.0 in loop (Branch (pb, b, a) :: acc) (add (below, above) (a, pa)) | (below, (x, _) :: above) | ((x, _) :: below, above) -> loop (Single x :: acc) (below, above) | ([], []) -> Array.of_list acc in loop [] (List.fold ~f:add ~init:([], []) dist) TEST_MODULE = struct let probs = [ ("A", 0.083); ("B", 0.084); ("C", 0.333); ("D", 0.500); ] let t : string t = create probs let histogram = String.Table.create ~size:5 () let num_samples = 1_000_000 TEST_UNIT = for _i = 1 to num_samples do let key = sample t in incr (Hashtbl.find_or_add histogram key ~default:(fun () -> ref 0)) done let test_outcome key = let prob = List.Assoc.find_exn probs key in let count = !(Hashtbl.find_exn histogram key) in let percentage = float count /. float num_samples in if Float.abs (percentage -. prob) > 0.001 then failwithf "prob = %G; percentage = %G" prob percentage () TEST_UNIT = test_outcome "A" TEST_UNIT = test_outcome "B" TEST_UNIT = test_outcome "C" TEST_UNIT = test_outcome "D" end ```