Commits

abhimanu  committed c8edf98

span prediction on wiki-bio

  • Participants
  • Parent commits c9f772b

Comments (0)

Files changed (5)

File .cache

Binary file modified.

File src/main/scala/opennlp/textgrounder/app/RunApps.scala

     
     if(args.length==0){
     println("default args")
-    val args = new Array[String](22)
+    val args = new Array[String](26)
     args(0) = "--input-corpus"
 //    args(1) = "/home/abhimanu/textgrounder_temporal/data/corpora/temporal/docthresh-5"
-    args(1) = "/home/abhimanu/datasets/textgrounder/temporal/gutts/" //wiki-bio gutts wiki-years
+    args(1) = "/home/abhimanu/datasets/textgrounder/temporal/wiki-bio/" //wiki-bio gutts wiki-years
     args(2) = "--width-of-multi-cell"
-    args(3) = "40"
+    args(3) = "10"
     args(4) = "--eval-set"
     args(5) = "dev"
     args(6) = "--word-dist"
     //NOTE: Naive Bayes has been currently coded for only Dirichlet/JM Smoothing
     //By default Bayes is chronon-docs
     args(12) = "--smoothing-par"
-    args(13) = "0.99"					//for JM=0.99  and cikm = 0.01
+    args(13) = "0.999"					//for JM=0.99  and cikm = 0.01
     args(14) = "--bayes-prior"
     args(15) = "chronon-docs"  					//"uniform" "chronon-docs" ; default=chronon-docs
     args(16) = "--smoothing-type"  
     args(19) = "3"
     args(20) = "--error-limit"
     args(21) = "107"
+    args(22) = "--span-pred"
+    args(23) = "variance"
+    args(24) = "--span-parameter"
+    args(25) = "0.3"
     
       //NOTE: always remember to set "temporal-dirichlet" before doing bayes
     TemporalDocumentApp.main(args)

File src/main/scala/opennlp/textgrounder/geolocate/Geolocate.scala

       help = """Chooses Bayes Prior. Options are "uniform" and "chronon-docs".
 Default '%default'.""")
 
+  var span_pred =
+    ap.option[String]("span-pred", "span", metavar = "span_pred",
+      default = "none",
+      choices = Seq("none", "variance", "trimming"),
+      help = """Perform span prediction.""")
+
+  var span_parameter =
+    ap.option[Double]("span-parameter", "span-para", metavar = "NUM",
+      default = 0.6,
+      help = """The span parameter (kappa or gamma).""")
+
+  var span_big_width =
+    ap.option[Int]("span-big-width", "span-big", metavar = "NUM",
+      default = 100,
+      help = """The span parameter (big delta).""")
+      
   var analysis_word_document_frequency_limit =
     ap.option[Int]("frequency-limit", "freq-limit", metavar = "NUM",
       default = 1,

File src/main/scala/opennlp/textgrounder/temporal/SimpleTimeCell.scala

       record_created_cell = false)		//FIXME should I create this if not found?
   }
 
+  def getSpan(pred_cells: Array[(TemporalCell, Double)], spanType: String, parameter: Double, 
+      delta: Int, bigDelta: Int) = {
+//    mutable.Buffer
+//    val numChronons = bigDelta/delta + (if((bigDelta%delta) >0) 1 else 0)
+//    val numLeftChronon = (numChronons-1)/2
+//    val numRightChronon = (if(numChronons%2>0) numLeftChronon+1 else numLeftChronon)
+    val midChronon = pred_cells(0)._1
+    val midChrononMid = (midChronon.get_center_coord().left + midChronon.get_center_coord().right)/2
+//    val leftArray = mutable.Buffer[(TemporalCell, Double)]()
+//    val rightArray = mutable.Buffer[(TemporalCell, Double)]()
+    val mean = pred_cells(0)._2
+    var leftVariance=0.0
+    var rightVariance=0.0
+    var leftChronon = 0
+    var rightChronon = 0
+    var valueSum = 0.0
+    for ((cell,value) <- pred_cells){
+      val cellMid = (cell.get_center_coord().left + cell.get_center_coord().right)/2
+      if (cellMid<midChrononMid && cellMid>= midChrononMid-bigDelta/2){ // left variance
+        leftVariance += (mean-value)*(mean-value)
+        leftChronon+=1
+      }
+      if (cellMid>midChrononMid && cellMid <= midChrononMid+bigDelta/2){ // right variance
+        rightVariance += (mean-value)*(mean-value)
+        rightChronon+=1
+      }
+      valueSum+=value
+    }
+    leftVariance = math.sqrt(leftVariance/(valueSum*valueSum*leftChronon))	// divide by valueSum to normalize
+    rightVariance = math.sqrt(rightVariance/(valueSum*valueSum*rightChronon))
+    val leftPredicted = midChrononMid-parameter*leftVariance
+    val rightPredicted = midChrononMid+parameter*rightVariance
+    getPrecAndRecall(leftPredicted.intValue(), rightPredicted.intValue(), midChronon.get_center_coord().left, 
+        midChronon.get_center_coord().right)
+  }
+  
+  def getPrecAndRecall(leftPredicted: Int, rightPredicted: Int, leftTrue: Int, rightTrue: Int):(Double, Double) = {
+    val IBleft = leftPredicted;
+    val IBright = rightPredicted;
+    val IGleft = leftTrue;
+    val IGright = rightTrue;
+    if (IBright <= IGleft)
+    	return (0.0,0.0)
+    if (IBleft >= IGright)
+    	return (0.0,0.0)
+    if (IBright == IBleft)
+    	return (0.0,0.0)
+    var overlapLeft = 0;
+    var overlapRight = 0;
+    if ((IBleft >= IGleft && IBleft <= IGright) || (IBright >= IGleft && IBright <= IGright) 
+    		|| (IBleft <= IGleft && IBright >= IGright) ){
+
+    	overlapLeft = if(IGleft>IBleft) IGleft else IBleft 	//max(IGleft, IBleft);
+    	overlapRight = if(IGright<IBright) IGright else IBright	//min(IGright, IBright);
+    }
+    val precision = (overlapRight - overlapLeft) * 1.0 / (IBright - IBleft);
+    val recall = (overlapRight - overlapLeft) * 1.0 / (IGright - IGleft);
+    return (precision, recall)
+  }
+  
   //FIXME need to change this ?
   protected def find_cell_for_cell_index(index: RegularCellIndex,
       create: Boolean, record_created_cell: Boolean) = {

File src/main/scala/opennlp/textgrounder/temporal/TemporalEvaluation.scala

   driver_stats, prefix, max_rank_for_credit) {
   val degree_dists = mutable.Buffer[Double]()
   val oracle_degree_dists = mutable.Buffer[Double]()
+  val span_precision = mutable.Buffer[Double]()
+  val span_recall = mutable.Buffer[Double]()
+  val span_fScore = mutable.Buffer[Double]()
   val error_histograms = mutable.HashMap[Double,(Int, Double,String)]()
   val word_ErrorHistogram = mutable.HashMap[String,(Int, Double, Double)]()
   val worst_docs = mutable.HashMap[TemporalDocument,(Double,Int,TemporalCell,TemporalCell)]()
     degree_dists += pred_degree_dist
   }
   
+  def record_prec_recall(precision: Double, recall: Double){
+    span_precision += precision
+    span_recall += recall
+    if (precision!=0 && recall!=0){
+      val fScore = 2*precision*recall/(precision+recall)
+      span_fScore += fScore
+    }
+  }
+  
   def record_histogram(true_center: TemporalCoord, pred_degdist: Double){
     val key = (true_center.left+true_center.right)*1.0/2
     var value=error_histograms.get(key)
       mean(degree_dists))
     errprint("  Median error distance = %.2f Years",
       median(degree_dists))
-    if(all_results_flag){
-//      val temp_map = mutable.ArrayBuffer[(String, Double,Int)]()
-      println("\n\n\n\n\t\t===================Analysis Stats==================\n")
-      println("\t\t===================Error Histogram==================\n")
-    	for(i <- error_histograms.keySet.toList.sorted){
-    	  val value = error_histograms.get(i)
-    	  println(value.get._3+"\t",value.get._2/value.get._1,value.get._1)
-    	}
-      println("\t\t===================Word Errors==================\n")	
-      for(value <- word_ErrorHistogram.toList sortBy {-_._2._3}){
-        if(value._2._1>limit_flag)
-    		println(value._1+"\t",value._2._2/value._2._1, value._2._3,value._2._1)
-    	}
-      println("\t\t===================Worst Docs==================\n")
-      for(value <- worst_docs.toList sortBy {-_._2._1}){
-        if(value._2._1>100)
-    		println(value._1,value._2._1, value._2._2," True Cell: "+value._2._3," Pred Cell"+value._2._4)
-    	}
-    }
+    
+    println("  Precision = %.2f Years",
+      mean(span_precision))
+    println("  Recall = %.2f Years",
+      mean(span_recall))
+    println("  Fscore = %.2f Years",
+      mean(span_fScore))
+//    if(all_results_flag){
+////      val temp_map = mutable.ArrayBuffer[(String, Double,Int)]()
+//      println("\n\n\n\n\t\t===================Analysis Stats==================\n")
+//      println("\t\t===================Error Histogram==================\n")
+//    	for(i <- error_histograms.keySet.toList.sorted){
+//    	  val value = error_histograms.get(i)
+//    	  println(value.get._3+"\t",value.get._2/value.get._1,value.get._1)
+//    	}
+//      println("\t\t===================Word Errors==================\n")	
+//      for(value <- word_ErrorHistogram.toList sortBy {-_._2._3}){
+//        if(value._2._1>limit_flag)
+//    		println(value._1+"\t",value._2._2/value._2._1, value._2._3,value._2._1)
+//    	}
+//      println("\t\t===================Worst Docs==================\n")
+//      for(value <- worst_docs.toList sortBy {-_._2._1}){
+//        if(value._2._1>100)
+//    		println(value._1,value._2._1, value._2._2," True Cell: "+value._2._3," Pred Cell"+value._2._4)
+//    	}
+//    }
 //    errprint("  Median oracle true error distance = %s",
 //      km_and_miles(median(oracle_true_dists)))
   }
     new DoubleTableByRange(dist_fractions_for_error_dist,
       create_stats_for_range("degree_dist_to_pred_center", _))
 
+  def record_prec_recall(precision: Double, recall: Double){
+    all_document.record_prec_recall(precision, recall)
+  }
+  
   override def record_one_result(stats: TBasicEvalStats,
       res: TDocEvalRes) {
     val parameters = cell_grid.table.driver.params
       }
     val result =
       new TemporalDocumentEvaluationResult(document, pred_cells(0)._1, true_rank)
-
+    
+    if (driver.params.span_pred!="none"){
+      val (precision, recall)=strategy.cell_grid.asInstanceOf[SimpleTimeCellGrid].getSpan(pred_cells, 
+          driver.params.span_pred, driver.params.span_parameter, driver.params.width_of_multi_cell, 
+          driver.params.span_big_width);
+      evalstats.asInstanceOf[TemporalGroupedDocumentEvalStats].record_prec_recall(precision, recall);
+    }
+    
     if(document.title=="364810Rolf Dieter Brinkmann"){
     if (debug("all-scores")) {
       for (((cell, value), index) <- pred_cells.zipWithIndex) {