Commits

Stephen Roller committed 6ed2f8c

Add a kd-split-method command line option.

Comments (0)

Files changed (3)

src/main/java/ags/utils/KdTree.java

  * @author Rednaxela
  */
 public class KdTree<T> {
+    // split method enum
+    public enum SplitMethod { HALFWAY, MEDIAN, MAX_MARGIN }
+
     // All types
     private final int                  dimensions;
     private final KdTree<T>            parent;
     private int                        bucketSize;
+    private SplitMethod                splitMethod;
  
     // Leaf only
     private double[][]                 locations;
      * Construct a KdTree with a given number of dimensions and a limit on
      * maxiumum size (after which it throws away old points)
      */
-    public KdTree(int dimensions, int bucketSize) {
+    public KdTree(int dimensions, int bucketSize, SplitMethod splitMethod) {
         this.bucketSize = bucketSize;
         this.dimensions = dimensions;
+        this.splitMethod = splitMethod;
  
         // Init as leaf
         this.locations = new double[bucketSize][];
     private KdTree(KdTree<T> parent, boolean right) {
         this.dimensions = parent.dimensions;
         this.bucketSize = parent.bucketSize;
+        this.splitMethod = parent.splitMethod;
  
         // Init as leaf
         this.locations = new double[Math.max(bucketSize, parent.locationCount)][];
         while (cursor.locations == null || cursor.locationCount >= cursor.locations.length) {
             if (cursor.locations != null) {
                 cursor.splitDimension = cursor.findWidestAxis();
-                //cursor.splitValue = (cursor.minLimit[cursor.splitDimension] + cursor.maxLimit[cursor.splitDimension]) * 0.5;
 
-                List<Double> list = new ArrayList<Double>();
-                for(int i=0;i<cursor.locations.length;i++) {
-                    list.add(cursor.locations[i][cursor.splitDimension]);
-                }
-                Collections.sort(list);
-                if(list.size()%2 == 1) {
-                    cursor.splitValue = list.get(list.size()/2);
-                } else {
-                    cursor.splitValue = (list.get(list.size()/2) + list.get(list.size()/2 - 1))/2;
+                if (splitMethod == SplitMethod.HALFWAY) {
+                    cursor.splitValue = (cursor.minLimit[cursor.splitDimension] + 
+                                         cursor.maxLimit[cursor.splitDimension]) * 0.5;
+                } else if (splitMethod == SplitMethod.MEDIAN) {
+                    // split on the median of the elements
+                    List<Double> list = new ArrayList<Double>();
+                    for(int i = 0; i < cursor.locations.length; i++) {
+                        list.add(cursor.locations[i][cursor.splitDimension]);
+                    }
+                    Collections.sort(list);
+                    if(list.size() % 2 == 1) {
+                        cursor.splitValue = list.get(list.size() / 2);
+                    } else {
+                        cursor.splitValue = (list.get(list.size() / 2) + list.get(list.size() / 2 - 1))/2;
+                    }
+                } else if (splitMethod == SplitMethod.MAX_MARGIN) {
+                    List<Double> list = new ArrayList<Double>();
+                    for(int i = 0; i < cursor.locations.length; i++) {
+                        list.add(cursor.locations[i][cursor.splitDimension]);
+                    }
+                    Collections.sort(list);
+                    double maxMargin = 0.0;
+                    double splitValue = Double.NaN;
+                    for (int i = 0; i < list.size() - 1; i++) {
+                        double delta = list.get(i+1) - list.get(i);
+                        if (delta > maxMargin) {
+                            maxMargin = delta;
+                            splitValue = list.get(i) + 0.5 * delta;
+                        }
+                    }
+                    cursor.splitValue = splitValue;
                 }
 
                 // Never split on infinity or NaN
     }
 
     public static void main(String[] args) {
-        KdTree<String> tree = new KdTree<String>(2, 2);
+        KdTree<String> tree = new KdTree<String>(2, 2, SplitMethod.HALFWAY);
         tree.addPoint(new double[] { 1.0, 1.0 }, "hello1");
         tree.addPoint(new double[] { 10.0, 2.0 }, "world2");
         tree.addPoint(new double[] { 3.0, 4.0 }, "earth3");

src/main/scala/opennlp/textgrounder/geolocate/Geolocate.scala

 center calculation. Options are either 'centroid' or 'center'.
 Default '%default'.""")
 
+  var kd_split_method =
+    ap.option[String]("kd-split-method", "kdsm", metavar = "SPLIT_METHOD",
+      default = "halfway",
+      choices = Seq("halfway", "median", "maxmargin"),
+      help = """Chooses which leaf-splitting method to use. Valid options are
+'halfway', which splits into two leaves of equal degrees, 'median', which
+splits leaves to have an equal number of documents, and 'maxmargin',
+which splits at the maximum margin between two points. All splits are always
+on the longest dimension. Default '%default'.""")
+
+
 
   //// Options used when creating word distributions
   var word_dist =
 
   protected def initialize_cell_grid(table: DistDocumentTable) = {
     if (params.use_kd_tree)
-      new KdTreeCellGrid(table, params.kd_bucketsize)
+      KdTreeCellGrid(table, params.kd_bucketsize, params.kd_split_method)
     else
       new MultiRegularCellGrid(degrees_per_cell,
         params.width_of_multi_cell, table)

src/main/scala/opennlp/textgrounder/geolocate/KDTreeCellGrid.scala

   }
 }
 
-class KdTreeCellGrid(table: DistDocumentTable, bucketSize: Int)
+object KdTreeCellGrid {
+  def apply(table: DistDocumentTable, bucketSize: Int, splitMethod: String) : KdTreeCellGrid = {
+    new KdTreeCellGrid(table, bucketSize, splitMethod match {
+      case "halfway" => KdTree.SplitMethod.HALFWAY
+      case "median" => KdTree.SplitMethod.MEDIAN
+      case "maxmargin" => KdTree.SplitMethod.MAX_MARGIN
+    })
+  }
+}
+
+class KdTreeCellGrid(table: DistDocumentTable, bucketSize: Int, splitMethod: KdTree.SplitMethod)
     extends CellGrid(table) {
   /**
    * Total number of cells in the grid.
    */
   var total_num_cells: Int = 0
-  var kdtree : KdTree[DistDocument] = new KdTree[DistDocument](2, bucketSize);
+  var kdtree : KdTree[DistDocument] = new KdTree[DistDocument](2, bucketSize, splitMethod);
   val leaves_to_cell : Map[KdTree[DistDocument], KdTreeCell] = Map();
 
   /**