1. Alexandre Patry
  2. mallet

Commits

Gregory Druck  committed 4e6299f

Minor changes and fixes to cc.mallet.classify.FeatureConstraintUtil.

  • Participants
  • Parent commits 5b24433
  • Branches default

Comments (0)

Files changed (1)

File src/cc/mallet/classify/FeatureConstraintUtil.java

View file
  • Ignore whitespace
     return features;
   }  
   
+  public static HashMap<Integer,double[]> setTargetsUsingData(InstanceList list, ArrayList<Integer> features) {
+    return setTargetsUsingData(list,features,true);
+  }
+  
+  public static HashMap<Integer,double[]> setTargetsUsingData(InstanceList list, ArrayList<Integer> features, boolean normalize) {
+    return setTargetsUsingData(list,features,false,normalize);
+  }
+  
   /**
    * Set targets using estimates from data (smoothed, normalized distributions
    * when normalize is true; unnormalized feature-label counts otherwise).
    * 
-   * @param list InstanceList used to estimate target distributions.
+   * @param list InstanceList used to estimate targets.
    * @param features List of features for constraints.
-   * @return Constraints (map of feature index to target distribution), with target
-   *         distributions set using estimates from supplied data.
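+   * @param useValues Whether to weight each feature occurrence by its value in
+   *         the feature vector (true) or to count each occurrence as 1.0 (false).
+   * @param normalize Whether to smooth the counts for each feature by a very
+   *         small amount and rescale them so that they sum to one.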
+   * @return Constraints (map of feature index to target), with targets
+   *         set using estimates from supplied data.
    */
-  public static HashMap<Integer,double[]> setTargetsUsingData(InstanceList list, ArrayList<Integer> features) {
+  public static HashMap<Integer,double[]> setTargetsUsingData(InstanceList list, ArrayList<Integer> features, boolean useValues, boolean normalize) {
     HashMap<Integer,double[]> constraints = new HashMap<Integer,double[]>();
     
-    double[][] featureLabelCounts = getFeatureLabelCounts(list);
+    double[][] featureLabelCounts = getFeatureLabelCounts(list,useValues);
 
     for (int i = 0; i < features.size(); i++) {
       int fi = features.get(i);
       if (fi != list.getDataAlphabet().size()) {
         double[] prob = featureLabelCounts[fi];
-        // Smooth probability distributions by adding a (very)
-        // small count.  We just need to make sure they aren't
-        // zero in which case the KL-divergence is infinite.
-        MatrixOps.plusEquals(prob, 1e-8);
-        MatrixOps.timesEquals(prob, 1./MatrixOps.sum(prob));
+        if (normalize) {
+          // Smooth probability distributions by adding a (very)
+          // small count.  We just need to make sure they aren't
+          // zero in which case the KL-divergence is infinite.
+          MatrixOps.plusEquals(prob, 1e-8);
+          MatrixOps.timesEquals(prob, 1./MatrixOps.sum(prob));
+        }
         constraints.put(fi, prob);
       }
     }
    * 
    * @param list InstanceList used to compute statistics for labeling features.
    * @param features List of features to label.
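+   * @param reject Whether the labeler may decline to label a feature: when true,
+   *         features with information gain below the cutoff, or that would be
+   *         assigned more than half of the labels, are skipped.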
    * @return Labeled features, HashMap mapping feature indices to list of labels.
    */
-  public static HashMap<Integer, ArrayList<Integer>> labelFeatures(InstanceList list, ArrayList<Integer> features) {
+  public static HashMap<Integer, ArrayList<Integer>> labelFeatures(InstanceList list, ArrayList<Integer> features, boolean reject) {
     HashMap<Integer,ArrayList<Integer>> labeledFeatures = new HashMap<Integer,ArrayList<Integer>>();
     
-    double[][] featureLabelCounts = getFeatureLabelCounts(list);
+    double[][] featureLabelCounts = getFeatureLabelCounts(list,true);
     
     int numLabels = list.getTargetAlphabet().size();
     
       
       // reject features with infogain
       // less than cutoff
-      if (infogain.value(fi) < mean) {
+      if (reject && infogain.value(fi) < mean) {
+        //System.err.println("Oracle labeler rejected labeling: " + list.getDataAlphabet().lookupObject(fi));
         logger.info("Oracle labeler rejected labeling: " + list.getDataAlphabet().lookupObject(fi));
         continue;
       }
           if (prob[li] > threshold) {
             labels.add(li);
           }
-          if (labels.size() > (numLabels / 2)) {
+          if (reject && labels.size() > (numLabels / 2)) {
+            //System.err.println("Oracle labeler rejected labeling: " + list.getDataAlphabet().lookupObject(fi));
+            logger.info("Oracle labeler rejected labeling: " + list.getDataAlphabet().lookupObject(fi));
             discard = true;
             break;
           }
     return labeledFeatures;
   }
   
-  private static double[][] getFeatureLabelCounts(InstanceList list) {
+  public static HashMap<Integer, ArrayList<Integer>> labelFeatures(InstanceList list, ArrayList<Integer> features) {
+  	return labelFeatures(list,features,true);
+  }
+  
+  private static double[][] getFeatureLabelCounts(InstanceList list, boolean useValues) {
     int numFeatures = list.getDataAlphabet().size();
     int numLabels = list.getTargetAlphabet().size();
     
         double py = instance.getLabeling().value(li);
         for (int loc = 0; loc < featureVector.numLocations(); loc++) {
           int fi = featureVector.indexAtLocation(loc);
-          double val = featureVector.valueAtLocation(loc);
+          double val;
+          if (useValues) {
+            val = featureVector.valueAtLocation(loc);
+          }
+          else {
+            val = 1.0;
+          }
           featureLabelCounts[fi][li] += py * val;
         }
       }