Commits

Elias Ponvert committed c9f8126

fixed prob with punc and sent len; cosmetic changes; more iterators; new CLI option

Comments (0)

Files changed (1)

 
 
 import string
-import itertools
+from itertools import imap, ifilter, chain
 import random
 import math
 from sys import stdout, stderr, exit
     """Iterator over the brackets.
     """
     if unary and (whole or self.length > 1):
-      c1 = itertools.imap(lambda a: (a, a+1), range(self.start_index, self.start_index+self.length))
+      c1 = imap(lambda a: (a, a+1), range(self.start_index, self.start_index+self.length))
     else:
       c1 = []
     if whole and self.length > 1:
     else:
       c3 = []
     
-    return itertools.chain(c1, self.brackets, c3)
+    return chain(c1, self.brackets, c3)
   
   def set_start_index(self, start_index):
     """Change internal representation.
     
     # No da igual si se consideran empty spans:
     #for s in self.S:
-    for s in itertools.ifilter(lambda s: len(s) >= self.min_sentence_length, self.S):
+    for s in ifilter(lambda s: len(s) >= self.min_sentence_length, self.S):
       p_bracket = self.p_bracket(s, init)
       self.MStep_s(s, p_bracket, tita)
     
     #for s in self.S:
     p_true = 0.0
     p_false = 0.0
-    for s in itertools.ifilter(lambda s: len(s) > 1, self.S):
+    for s in ifilter(lambda s: len(s) > 1, self.S):
       p_bracket = self.p_bracket(s, init)
       
       n = len(s)
 PUNC_SYM = ';'
 
 def sentences_n(fname, n=-1):
-  lst = [s.split() for s in open(fname)]
-  if n >= 0:
-    return [s for s in lst if len(s) <= n]
-  else:
-    return lst
+  for s in open(fname):
+    s = s.split()
+    if n < 0 or len(notpunc(s, keep_punc)) <= n:
+      yield s
+
+def notpunc(lst):
+  return [w for w in lst if w != PUNC_SYM]
 
 def bracket_sentences_n(fname, n, punc_pos, keep_punc=[]):
 
 
         open_p -= 1
         if open_p == 0:
-          lst.append(sent)
+          if n < 0 or len(notpunc(sent)) <= n:
+            yield sent
+
           sent = []
     
       else:
         current.append(ch)
 
-  if n >= 0:
-    return [s for s in lst if len(s) <= n]
-
-  else:
-    return lst
-
 def sentences_n_fmt(fname, n=-1, fmt='txt', punc=False):
 
   keep_punc = punc and KEEP_PUNC or []
   
   if fmt == 'txt':
-    return sentences_n(fname, n, keep_punc=keep_punc)
+    return sentences_n(fname, n)
 
   elif fmt == 'wsj':
     return bracket_sentences_n(fname, n, WSJ_PUNC_POS, keep_punc=keep_punc)
 
   return sents_new, stitch
 
-def stitch_sentences(chunks, stitch):
+def stitch_sentences(segments, stitch):
   # convert iterator to list if necessary
-  chunks = list(chunks)
+  segments = list(segments)
   start = 0
-  for len in stitch:
-    end = start+len
-    yield '(' + (' '.join(chunks[start:end])) + ')'
+  for n in stitch:
+    end = start + n
+    chunks = filter(len, segments[start:end])
+    yield '(' + (' '.join(chunks)) + ')'
     start = end
 
 if __name__ == '__main__':
 
   op = OptionParser()
   op.add_option('-t', '--train')
-  op.add_option('-s', '--test')
+  op.add_option('-T', '--test')
   op.add_option('-o', '--output')
   op.add_option('-A', '--output_all', action='store_true')
   op.add_option('-n', '--len_trn', type='int', default=10)
-  op.add_option('-N', '--len_tst', type='int', default=10)
+  op.add_option('-N', '--len_tst', type='int', default=-2)
   op.add_option('-f', '--fmt_trn', default='txt')
-  op.add_option('-F', '--fmt_tst', default='txt')
+  op.add_option('-F', '--fmt_tst')
   op.add_option('-I', '--em_iter', type='int', default=10)
   op.add_option('-2', '--ccm2', action='store_true')
   op.add_option('-p', '--punc', action='store_true')
 
   opt, args = op.parse_args()
 
+  # set defaults for test set based on train set
+  if opt.test:
+    if not(opt.fmt_tst):
+      opt.fmt_tst = opt.fmt_trn
+
+    if opt.len_tst == -2:
+      opt.len_tst = opt.len_trn
+
   kwargs = dict(n=opt.len_trn, fmt=opt.fmt_trn, punc=opt.punc)
-  train_f = sum([glob(t) for t in opt.train.split()], [])
-  train = sum([sentences_n_fmt(f, **kwargs) for f in train_f], [])
+  mksents = lambda f:sentences_n_fmt(f, **kwargs)
+
+  train_f = chain.from_iterable(glob(t) for t in opt.train.split())
+  train = chain.from_iterable(imap(mksents, train_f))
 
   if opt.punc:
     train, stitch = unstitch_sentences(train)
 
+  # future: avoid using lists
+  train = list(train)
+
   if opt.test:
-    kwargs = dict(n=opt.len_tst, fmt=opt.fmt_tst)
-    test_f = sum([glob(t) for t in opt.test.split()], [])
-    test = sum([sentences_n_fmt(f, **kwargs) for f in test_f], [])
+    test_f = chain.from_iterable(glob(t) for t in opt.test.split())
+
+    # check it: the lambda mksents is using the newly defined kwargs
+    kwargs = dict(n=opt.len_tst, fmt=opt.fmt_tst, punc=opt.punc)
+    test = chain.from_iterable(imap(mksents, test_f))
+
+    # future: avoid using lists
+    test = list(test)
+
 
     if opt.punc:
       test, stitch = unstitch_sentences(test)