Commits

vvcephei committed 54942f9

Added TokenChecker and made Merge also check tokens. I kind of hate this change; when I have some time, I need to come back and refactor it.

Comments (0)

Files changed (4)

     kin|mlg)
         echo "merging $1"
         treesrc=$root/$1/tree/src
-        treedst=$root/$1/tree
         for collection in $treesrc/* ; do
             echo $collection
             for tree in $collection/* ; do
                     collection=$( basename $( dirname $fullpath ) )
                     lang=$( basename $( dirname $( dirname $( dirname $( dirname $fullpath ) ) ) ) )
                     echo "merging $fullpath to $root/$lang/tree/$collection/$base.tree"
-                    #echo "running: scalabha run opennlp.scalabha.tree.Merge -i $fullpath -o $root/$lang/tree/$collection/$base.tree"
+                    outputpath=$root/$lang/tree/$collection/$base.tree
+                    tokpath=$root/$lang/tok/$collection/$base.tok
                     if [[ $2 == "-f" ]]; then
-                      scalabha run opennlp.scalabha.tree.Merge -f -i $fullpath -o $root/$lang/tree/$collection/$base.tree
+                      scalabha run opennlp.scalabha.tree.Merge -f -i $fullpath -o $outputpath --tok $tokpath
                     else
-                      scalabha run opennlp.scalabha.tree.Merge --pprintErrs -i $fullpath -o $root/$lang/tree/$collection/$base.tree
+                      scalabha run opennlp.scalabha.tree.Merge --pprintErrs -i $fullpath -o $outputpath --tok $tokpath
                     fi
                     (( exit_code += $? ))
                 else

src/main/scala/opennlp/scalabha/tree/Merge.scala

   val skipErrs = parser.flag[Boolean](List("f", "skipErrs"), "Do not exit on errors. " +
     "The default is to exit as soon as errors are caught in any input file.")
   val pprintErrs = parser.flag[Boolean](List("pprintErrs"), "Format treenodes nicely in error reporting.")
+  val tokenFileOpt = parser.option[String](List("tok"), "FILE", "Optional. If present, the file that contains the tokens to " +
+    "check the tree against.")
+  // Matches "<fileId>.<lang>.<treeNum>.tree" and captures the three dot-separated
+  // parts. The separators are escaped so they match only a literal '.'; left
+  // unescaped, each one matched any character, so malformed names could match
+  // with an empty tree number.
+  val TreeFileName = """([^.]*)\.([^.]*)\.([^.]*)\.tree""".r
+  var tokensFromFile: List[List[String]] = Nil
 
   var log = new SimpleLogger(this.getClass().getName, SimpleLogger.WARN, new BufferedWriter(new OutputStreamWriter(System.err)))
 
 
   def applyFile(file: File): List[TreeNode] = {
-    if (file.getName.endsWith("tree"))
-      MultiLineTreeParser(file.getName, scala.io.Source.fromFile(file, "UTF-8"))
-    else {
+    if (file.getName.endsWith("tree")) {
+      val trees: List[TreeNode] = MultiLineTreeParser(file.getName, scala.io.Source.fromFile(file, "UTF-8"))
+      if (tokensFromFile.length != 0) {
+        val TreeFileName(file_id, lang, treeNum) = file.getName
+        val treeIndexBase = treeNum.toInt - 1
+
+        for ((tree, index) <- trees.zipWithIndex) {
+          val treeTokens = tree.getTokens()
+          val tokens = tokensFromFile(treeIndexBase + index)
+          val pass = TokenChecker.checkTokensInLine(treeTokens, tokens)
+        }
+      }
+      trees
+    } else {
       log.warn("Assuming that %s is not a tree file because it does not end with the extension '.tree'. Skipping...\n".format(file.getName))
       List[TreeNode]()
     }
         parser.usage()
       }
 
+      tokensFromFile = tokenFileOpt.value match {
+        case Some(filename) =>
+          scala.io.Source.fromFile(filename, "UTF-8").getLines.map(line => line.replace("<EOS>", "").split("\\s+").toList).toList
+        case _ =>
+          Nil
+      }
+
       MultiLineTreeParser.log.logLevel = SimpleLogger.WARN
       MultiLineTreeParser.pprintErrs = pprintErrs.value.isDefined
 
 
       val (compileWarnings, compileErrors) = log.getStats()
       val (parseWarnings, parseErrors) = MultiLineTreeParser.log.getStats()
-      val (warnings, errors) = (compileWarnings + parseWarnings, compileErrors + parseErrors)
+      val (tokenWarnings, tokenErrors) = TokenChecker.log.getStats()
+      val (warnings, errors) = (
+        compileWarnings + parseWarnings + tokenWarnings,
+        compileErrors + parseErrors + tokenErrors)
 
       log.summary("Warnings,Errors: %s\n".format((warnings, errors)))
       if (errors == 0 || skipErrs.value.isDefined) {

src/main/scala/opennlp/scalabha/tree/TagChecker.scala

 
   def checkTokensInLine(aList: List[String], bList: List[String]): String = {
     if (aList.length != bList.length) {
-      //log.err("Lists should be the same length: %s %s\n".format(aList, bList))
+      //log.err("Lists should be the same length: %s %s\n".format(treeTokens, tokFileTokens))
       "Fail: \n\ttree: %s is not the same length as \n\ttok:  %s".format(aList, bList)
     } else if (aList.length == 0) {
       ""

src/main/scala/opennlp/scalabha/tree/TokenChecker.scala

+package opennlp.scalabha.tree
+
+import opennlp.scalabha.log.SimpleLogger
+import java.io.{OutputStreamWriter, BufferedWriter}
+import org.clapper.argot.{ArgotUsageException, ArgotParser, ArgotConverters}
+import opennlp.scalabha.model.TreeNode
+import collection.mutable.HashMap
+
+/**
+ * Compares the tokens extracted from parsed trees (via `getTokens`) with the
+ * tokens listed in a token file.
+ *
+ * `checkTokensInLine` compares two already-tokenized lines, `checkTokens`
+ * compares two line-aligned inputs, and `main` wires both to the command line.
+ * Mismatches are reported through `log`, whose counters callers read with
+ * `log.getStats()`.
+ */
+object TokenChecker {
+
+  import ArgotConverters._
+  val parser = new ArgotParser(this.getClass.getName, preUsage = Some("Version 0.0"))
+  val help = parser.flag[Boolean](List("h", "help"), "print help")
+  val input = parser.option[String](List("i", "input"), "FILE", "input inputFile in which to check tokens")
+  val tokens = parser.option[String](List("tok"), "FILE", "tokens to check")
+  val silent = parser.flag[Boolean](List("s"), "Set this flag to silence warnings and errors in the tree parser.")
+  val log = new SimpleLogger(this.getClass.getName, SimpleLogger.WARN, new BufferedWriter(new OutputStreamWriter(System.err)))
+
+  /**
+   * Renders the entries of `map`, in sorted order, as "(key,value)" pairs
+   * separated by `join`.
+   *
+   * @param map  the entries to render
+   * @param join the separator placed between consecutive entries
+   * @return the joined entries, or "" for an empty map
+   */
+  def spprintRepr(map: Map[String, Int], join: String): String = {
+    // Join the entries directly rather than post-processing List.toString with
+    // a regex: that rewrote any ", " occurring inside a key and threw a
+    // MatchError when a key contained a line break.
+    map.toList.sorted.mkString(join)
+  }
+
+  // Splits one token-file line into tokens: removes the "<EOS>" marker, splits
+  // on whitespace, and drops the empty string that split() produces for a
+  // blank line or for leading whitespace.
+  private def tokenize(line: String): List[String] =
+    line.replace("<EOS>", "").split("\\s+").toList.filter(_.length > 0)
+
+  // True when the two tokens are equal, or when one is "-LRB-"/"-RRB-" and the
+  // other is the corresponding "(" / ")".
+  private def sameToken(a: String, b: String): Boolean =
+    a == b ||
+      (a == "-LRB-" && b == "(") || (b == "-LRB-" && a == "(") ||
+      (a == "-RRB-" && b == ")") || (b == "-RRB-" && a == ")")
+
+  /**
+   * Checks that two token lists agree position by position.
+   *
+   * "-LRB-" is accepted as equal to "(" and "-RRB-" as equal to ")", in either
+   * direction. A length difference or the first mismatching pair is reported
+   * once through `log.err`; the report always shows both complete lists.
+   *
+   * @param treeTokens    tokens taken from a tree
+   * @param tokFileTokens tokens taken from the token file
+   * @return true if the lists have the same length and every pair matches
+   */
+  def checkTokensInLine(treeTokens: List[String], tokFileTokens: List[String]): Boolean = {
+    if (treeTokens.length != tokFileTokens.length) {
+      log.err("Fail: \n\ttree: %s is not the same length as \n\ttok:  %s\n".format(treeTokens, tokFileTokens))
+      false
+    } else {
+      // A single pass over the zipped lists; find() stops at the first
+      // mismatch, so at most one error is logged per line.
+      treeTokens.zip(tokFileTokens).find(pair => !sameToken(pair._1, pair._2)) match {
+        case Some((a, b)) =>
+          log.err(("Fail: \"%s\" does not match \"%s\" in:" +
+            "\n\ttree:%s\n\t tok:%s\n").format(a, b, treeTokens, tokFileTokens))
+          false
+        case None =>
+          true
+      }
+    }
+  }
+
+  /**
+   * Checks each line of `infile` (parsed as a tree) against the line at the
+   * same position in `tokfile`.
+   *
+   * If the inputs have different numbers of lines, a warning is logged and
+   * only the lines present in both are checked.
+   *
+   * @param infile  lines to hand to MultiLineTreeParser, one tree per line
+   * @param tokfile lines of whitespace-separated tokens, optionally with "<EOS>"
+   * @return one "<index>: pass" / "<index>: fail" / parse-failure message per checked line
+   */
+  def checkTokens(infile: Iterator[String], tokfile: Iterator[String]): List[String] = {
+    val treeLines = infile.toList
+    val tokLines = tokfile.toList
+    if (treeLines.length != tokLines.length) {
+      // zip() below truncates to the shorter input; say so instead of hiding it.
+      log.warn("Input has %d lines but the token file has %d lines; only the first %d are checked.\n".format(
+        treeLines.length, tokLines.length, math.min(treeLines.length, tokLines.length)))
+    }
+
+    for (((inTreeLine, tokLine), index) <- (treeLines zip tokLines).zipWithIndex) yield {
+      MultiLineTreeParser("trees", index, inTreeLine) match {
+        case Some(root) =>
+          val inTreeTokens: List[String] = root.getTokens
+          if (checkTokensInLine(inTreeTokens, tokenize(tokLine)))
+            "%d: pass".format(index)
+          else
+            "%d: fail".format(index)
+        case _ => "%d: Fail - Couldn't parse tree. See parser log messages.".format(index)
+      }
+    }
+  }
+
+  /**
+   * Command-line entry point: parses the options, checks the input file against
+   * the token file, prints one result line per checked line, and logs a summary
+   * of the warning and error counts. Usage errors are printed to standard output.
+   *
+   * @param args command-line arguments understood by `parser`
+   */
+  def main(args: Array[String]): Unit = {
+    try {
+      parser.parse(args)
+
+      if (help.value.isDefined) {
+        parser.usage()
+      }
+      if (silent.value.isDefined) {
+        MultiLineTreeParser.log.logLevel = SimpleLogger.NONE
+      }
+
+      // parser.usage() throws ArgotUsageException, which is handled below.
+      val inputName = input.value.getOrElse(parser.usage())
+      val tokenName = tokens.value.getOrElse(parser.usage())
+
+      // Close both sources even when checking throws.
+      val inputSource = scala.io.Source.fromFile(inputName, "UTF-8")
+      try {
+        val tokenSource = scala.io.Source.fromFile(tokenName, "UTF-8")
+        try {
+          log.trace("comparing tokens from %s to those in the trees in %s\n".format(tokenName, inputName))
+
+          println(checkTokens(inputSource.getLines(), tokenSource.getLines()).mkString("\n"))
+        } finally {
+          tokenSource.close()
+        }
+      } finally {
+        inputSource.close()
+      }
+
+      val (theseWarnings, theseErrors) = log.getStats()
+      val (parseWarnings, parseErrors) = MultiLineTreeParser.log.getStats()
+      val (warnings, errors) = (theseWarnings + parseWarnings, theseErrors + parseErrors)
+
+      log.summary("Warnings,Errors: %s\n".format((warnings, errors)))
+
+    } catch {
+      case e: ArgotUsageException =>
+        println(e.message)
+    }
+  }
+
+}