Commits

Elias Ponvert committed df93d14

support for stdin ; chunked input

Comments (0)

Files changed (1)

 
 
 import string
-from itertools import imap, ifilter, chain
+from itertools import imap, chain, izip
 import random
 import math
-from sys import stdout, stderr, exit
+from sys import stdout, stderr, exit, stdin
 from glob import glob
 
 class ParamDict:
  
   log_fh = stderr
   
-  def __init__(self, corpus):
+  def __init__(self, corpus, brackets):
 
     random.seed(0)
     self.tita = None
     count = ParamDict()
 
     self.S = corpus
+    self.known_brackets = brackets
 
     for s in self.S:
       n = len(s)
     self.Log_likelihood = 0.0
     
     # No da igual si se consideran empty spans:
-    #for s in self.S:
-    for s in ifilter(lambda s: len(s) >= self.min_sentence_length, self.S):
-      p_bracket = self.p_bracket(s, init)
+    for s, br in izip(self.S, self.known_brackets):
+      p_bracket = self.p_bracket(s, init, br)
       self.MStep_s(s, p_bracket, tita)
     
     for (a, b), p in tita.iteritems():
   def EStep(self):
     pass
   
-  
   # el s viene sin START y END...
-  def p_bracket(self, s, init):
+  def p_bracket(self, s, init, known_brackets):
+
     if self.tita is None:
       init = True
+
+    #debug else:
+      #debug print s
+      #debug print known_brackets
     
     n = len(s)
     
     result[0, n] = 1.0
     
     if not init:
-      (I, O) = self.IO(s)
+      (I, O) = self.IO(s, known_brackets)
       self.Log_likelihood += math.log(I[0, n])
     
     # antes l llegaba hasta n pero lo optimice:
           if I[i, j]*O[i, j] == 0.0:
             result[i, j] = 0.0
           else:
-            result[i, j] = I[i, j]*O[i, j]/(I[0, n]*self.phi(i, j, s))
+            denom = I[0, n]*self.phi(i, j, s, known_brackets)
+            if denom == 0.0:
+              result[i,j] = 0.0
+
+            else:
+              result[i, j] = I[i, j]*O[i, j]/denom
         #if not ((0.0 <= result[i, j]) and (result[i, j] <= 1.0001)):
         #  print 'Se pasoo result[i ,j]', result[i, j]
         #  print 's, i, j', s, i, j
     return result
   
   
-  def IO(self, s):
+  def IO(self, s, known_brackets):
     n = len(s)
     
-    I = self.I(s)
+    I = self.I(s, known_brackets)
     
     O = {}
     O[0, n] = 1.0
         j = i + l
         sum1 = sum(I[k, i]*O[k, j] for k in range(i))
         sum2 = sum(I[j, k]*O[i, k] for k in range(j+1, n+1))
-        O[i, j] = self.phi(i, j, s) * (sum1 + sum2)
+        O[i, j] = self.phi(i, j, s, known_brackets) * (sum1 + sum2)
     
     return (I, O)
   
   
   # Ver tesis de Klein, A.1:
-  def I(self, s):
+  def I(self, s, known_brackets):
     I = {}
     
     n = len(s)
         if l == 1:
           I[i, j] = 1.0
         elif l == 2:
-          I[i, j] = self.phi(i, j, s)
+          I[i, j] = self.phi(i, j, s, known_brackets)
         elif l == n:
           I[i, j] = sum(I[i, k] * I[k, j] for k in range(i+1, j))
         else:
-          I[i, j] = self.phi(i, j, s) * sum(I[i, k] * I[k, j] for k in range(i+1, j))
+          I[i, j] = self.phi(i, j, s, known_brackets) * \
+                    sum(I[i, k] * I[k, j] for k in range(i+1, j))
     
     return I
  
-  def parse_batch(self, strings, wrap=True):
-    if wrap:
-      for s in strings:
-        yield '(%s)' % self.parse(s)[0].strfy(s)
-    else:
-      for s in strings:
-        yield self.parse(s)[0].strfy(s)
+  def parse_batch(self, strings, brackets, wrap=True):
+    template = wrap and '(%s)' or '%s'
+    for st, br in izip(strings, brackets):
+      yield template % self.parse(st, br)[0].strfy(st)
   
-  def parse(self, s):
+  def parse(self, s, known_brackets):
+    parse = {}
+
     if s == []:
       return (Bracketing(0, set()), [])
-
-    parse = {}
     
     n = len(s)
     for l in range(1, n+1):
         if l == 1:
           parse[i, j] = (1.0, [])
         elif l == 2:
-          parse[i, j] = (self.phi(i, j, s), [(i, j)])
+          parse[i, j] = (self.phi(i, j, s, known_brackets), [(i, j)])
         else:
           max, k_max = None, []
           for k in range(i+1, j):
           # sin embargo da (casi) el mismo resultado: 68.9
           #k_max = k_max[0]
           
-          parse[i, j] = (max * self.phi(i, j, s), [(i, j)] + parse[i, k_max][1] + parse[k_max, j][1])
-   
+          parse[i, j] = (max * self.phi(i, j, s, known_brackets), \
+                         [(i, j)] + parse[i, k_max][1] + parse[k_max, j][1])
+
     result = (Bracketing(n, set(parse[0, n][1][1:])), parse[0, n][0])
     
     return result
   
   
-  def phi(self, i, j, s):
+  def phi(self, i, j, s, known_brackets):
+
+    #debug print s[i:j],(i,j)
+    #debug print known_brackets
+
     if self.tita is None:
+      #debug print 'tits is none'
       return 1.0
+    
+    if (i, j) in known_brackets:
+      #debug print 'in brackets'
+      return 1.0
+
+    for _i, _j in known_brackets:
+      if _i < i <= _j < j or i < _i < j < _j:
+        #debug print (i,j), 'crosses brackets', (_i,_j)
+        return 0.0
+
+    #debug print 'contingent'
+
     n = len(s)
     s = s + ['END', 'START']
     tita = self.tita
     alpha = string.join(s[i:j])
     beta = (s[i-1], s[j])
     # j -i == 1 y j - i == len(s) podrian estar fuera del dominio de phi tambien.
+
     if j - i == 1:
       # assert False
       result = tita.val((beta, True))
+      
     elif j - i == n:
       # se invoca en parse() y en IO()/I() creo que tambien
       # assert False
       result = tita.val((alpha, True))
+
     else:
       if tita.val((alpha, False)) * tita.val((beta, False)) == 0:
         print >>self.log_fh, 'i j s', i, j, s
     
     return result
 
-
 def p_bracket_split(i, j, n):
   if i == j:
     res = 0.0
   
   def MStep(self, init=False):
     tita = CCMTita()
-    # Da igual de las dos maneras:
-    #for s in self.S:
     p_true = 0.0
     p_false = 0.0
-    for s in ifilter(lambda s: len(s) > 1, self.S):
-      p_bracket = self.p_bracket(s, init)
+    for s, br in izip(self.S, self.known_brackets):
+      p_bracket = self.p_bracket(s, init, br)
       
       n = len(s)
       s = s + ['END', 'START']
 KEEP_PUNC = [',', '.', ';', '--']
 PUNC_SYM = ';'
 
+def sentences_n_bracks(fname, n=-1, fmt='chunked', punc=False):
+  sentences = []
+  brackets = []
+
+  if type(fname) == str:
+    fh = open(fname)
+
+  else:
+    fh = fname
+
+  for s in fh:
+    s = s.replace('(','( ').replace(')',' )')
+    s = s.split()
+    w_index = 0
+    chunk_st = -1
+    new_s = []
+    new_br = []
+    for t in s:
+      if t == PUNC_SYM:
+        assert chunk_st < 0
+        new_s.append(t)
+
+      elif t == '(':
+        assert chunk_st < 0
+        chunk_st = w_index
+
+      elif t == ')':
+        assert chunk_st >= 0
+        new_br.append((chunk_st, w_index))
+        chunk_st = -1
+
+      else:
+        new_s.append(t)
+        w_index += 1
+
+    if n < 0 or w_index <= n:
+      sentences.append(new_s)
+      brackets.append(new_br)
+
+  if type(fname) == str:
+    fh.close()
+
+  return sentences, brackets
+
 def sentences_n(fname, n=-1):
-  for s in open(fname):
+  
+  if type(fname) == str:
+    fh = open(fname)
+
+  else:
+    fh = fname
+
+  for s in fh:
     s = s.split()
     if n < 0 or len(notpunc(s)) <= n:
       yield s
 
+  if type(fname) == str:
+    fh.close()
+
 def notpunc(lst):
   return [w for w in lst if w != PUNC_SYM]
 
   lst = []
   sent = []
 
-  for line in open(fname):
+  if type(fname) == str:
+    fh = open(fname)
+
+  else:
+    fh = fname
+
+  for line in fh:
     for ch in line:
       if ch == '(':
         current = []
       else:
         current.append(ch)
 
+  if type(fname) == 'str':
+    fh.close()
+    
+
 def sentences_n_fmt(fname, n=-1, fmt='txt', punc=False):
 
   keep_punc = punc and KEEP_PUNC or []
     yield '(' + (' '.join(chunks)) + ')'
     start = end
 
+def files_from_glob(g):
+  train_f = chain.from_iterable(glob(t) for t in g.split())
+  train_f = list(train_f)
+  train_f.sort()
+  return train_f
+
 if __name__ == '__main__':
  
   from optparse import OptionParser
 
   op = OptionParser()
-  op.add_option('-t', '--train')
+  op.add_option('-t', '--train', default='-')
   op.add_option('-T', '--test')
   op.add_option('-o', '--output')
   op.add_option('-A', '--output_all', action='store_true')
 
   kwargs = dict(n=opt.len_trn, fmt=opt.fmt_trn, punc=opt.punc)
   mksents = lambda f:sentences_n_fmt(f, **kwargs)
+  mksents_bracks = lambda f:sentences_n_bracks(f, **kwargs)
 
-  train_f = chain.from_iterable(glob(t) for t in opt.train.split())
-  train = chain.from_iterable(imap(mksents, train_f))
+  if opt.fmt_trn == 'chunked':
+    
+    if opt.train == '-':
+      train, train_brack = mksents_bracks(stdin)
+
+    else:
+      train_f = files_from_glob(opt.train)
+      # TODO limitation, only one file allowed for now
+      train, train_brack = mksents_bracks(train_f[0])
+
+  else:
+
+    if opt.train == '-':
+      train = mksents(stdin)
+
+    else:
+      train_f = files_from_glob(opt.train)
+      train = chain.from_iterable(imap(mksents, train_f))
+
+    # future: avoid using lists
+    train = list(train)
+    train_brack = [[] for _ in train]
 
   if opt.punc:
     train, stitch = unstitch_sentences(train)
 
-  # future: avoid using lists
-  train = list(train)
+  if opt.test:
+    kwargs = dict(n=opt.len_tst, fmt=opt.fmt_tst, punc=opt.punc)
 
-  if opt.test:
-    test_f = chain.from_iterable(glob(t) for t in opt.test.split())
+    if opt.fmt_tst == 'chunked':
+      if opt.test == '-':
+        # check it: the lambda mksents_bracks is using the newly defined kwargs
+        test, test_brack = mksents_bracks(stdin)
 
-    # check it: the lambda mksents is using the newly defined kwargs
-    kwargs = dict(n=opt.len_tst, fmt=opt.fmt_tst, punc=opt.punc)
-    test = chain.from_iterable(imap(mksents, test_f))
+      else:
+        test_f = files_from_glob(opt.test)
+        # TODO limitation, only one file allowed for now
+        test, test_brack = mksents_bracks(test_f[0])
 
-    # future: avoid using lists
-    test = list(test)
+    else:
+      test_f = files_from_glob(opt.test)
 
+      # check it: the lambda mksents is using the newly defined kwargs
+      test = chain.from_iterable(imap(mksents, test_f))
+
+      # future: avoid using lists
+      test = list(test)
+      test_brack = [[] for _ in test]
 
     if opt.punc:
-      test, stitch = unstitch_sentences(test)
+       test, stitch = unstitch_sentences(test)
 
   else:
     test = train
+    test_brack = train_brack
 
   log_fh = stderr
 
   print >>log_fh, 'train corpus:', len(train), 
   print >>log_fh, 'sentences,', sum(map(len, train)), 'tokens'
 
-  ccm = (opt.ccm2 and CCM2 or CCM)(train)
+  ccm = (opt.ccm2 and CCM2 or CCM)(train, train_brack)
   for it in range(opt.em_iter):
     print >>log_fh, '%2d' % it,
     ccm.step()
     if opt.output and opt.output_all:
       out_fh = open('%s.%02d' % (opt.output, it), 'w')
 
-      if opt.punc:
-        for s in stitch_sentences(ccm.parse_batch(test, False), stitch):
-          print >>out_fh, s
-
-      else:
-        for s in ccm.parse_batch(test, True):
-          print >>out_fh, s
+      segments = ccm.parse_batch(test, test_brack, not opt.punc)
+      it = opt.punc and stitch_sentences(segments, stitch) or segments
+      for s in it:
+        print >>out_fh, s
 
       out_fh.close()
   
 
   out_fh = opt.output and open(opt.output, 'w') or stdout
 
-  if opt.punc:
-    for s in stitch_sentences(ccm.parse_batch(test, False), stitch):
-      print >>out_fh, s
-
-  else:
-    for s in ccm.parse_batch(test, True):
-      print >>out_fh, s
+  segments = ccm.parse_batch(test, test_brack, not opt.punc)
+  it = opt.punc and stitch_sentences(segments, stitch) or segments
+  for s in it:
+    print >>out_fh, s
 
   if opt.output is not None: 
     out_fh.close()
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.