Commits

Elias Ponvert committed f5ef883

wsj support

Comments (0)

Files changed (1)

 import itertools
 import random
 import math
-from sys import stdout, stderr
+from sys import stdout, stderr, exit
+from glob import glob
 
 class ParamDict:
   
     
     self.tita = tita
 
-def filtern(n, lst):
-  return [s for s in lst if len(s) <= n]
-
 def sentences_n(fname, n=-1):
   """Read whitespace-tokenized sentences, one per line, from fname.
 
   Returns a list of token lists, one per line of the file (an empty
   line gives an empty list).  When n >= 0 only sentences of at most n
   tokens are kept; a negative n (the default) keeps every sentence.
   """
   fh = open(fname)
   try:
     # Close the handle explicitly rather than leaving it to the
     # garbage collector.
     lst = [s.split() for s in fh]
   finally:
     fh.close()
 
   if n >= 0:
     return [s for s in lst if len(s) <= n]
   else:
     return lst
+
+def bracket_sentences_n(fname, n, punc_pos, keep_punc=()):
+  """Read POS-tag sequences from a file of parenthesized trees.
+
+  The file is scanned one character at a time.  At every ')' the text
+  collected since the last '(' (or since the last leaf) is examined;
+  when it is non-empty it must have the form "POS token".  A sentence
+  is emitted each time the parenthesis depth returns to zero.
+
+  fname     -- path of the bracketed file to read
+  n         -- keep only sentences of at most n tags; negative keeps all
+  punc_pos  -- POS tags whose leaves are dropped
+  keep_punc -- tokens that are kept even when their tag is in punc_pos
+
+  Returns a list of sentences, each a list of POS tags (the tokens
+  themselves are not returned).  Raises ValueError when a leaf does not
+  consist of exactly two whitespace-separated fields.
+  """
+
+  open_p = 0
+  current = []
+  lst = []
+  sent = []
+
+  fh = open(fname)
+  try:
+    for line in fh:
+      for ch in line:
+        if ch == '(':
+          # A new bracket discards whatever label text preceded it.
+          current = []
+          open_p += 1
+
+        elif ch == ')':
+          s = ''.join(current).strip()
+
+          if len(s) > 0:
+            fields = s.split()
+            if len(fields) != 2:
+              raise ValueError('malformed leaf %r in %s: expected "POS token"'
+                               % (s, fname))
+
+            pos, tok = fields
+            if tok in keep_punc or pos not in punc_pos:
+              sent.append(pos)
+
+            current = []
+
+          open_p -= 1
+          if open_p == 0:
+            # Depth back at zero: the outermost bracket just closed.
+            lst.append(sent)
+            sent = []
+
+        else:
+          current.append(ch)
+
+  finally:
+    # Close the handle explicitly rather than leaving it to the
+    # garbage collector.
+    fh.close()
+
+  if n >= 0:
+    return [s for s in lst if len(s) <= n]
+
+  else:
+    return lst
+
+# POS tags handed to bracket_sentences_n as punc_pos when reading 'wsj'
+# input; leaves carrying one of these tags are left out of the sentences.
+WSJ_PUNC_POS = ['.',',',"''",'``',':','(',')', '-NONE-', '-LRB-', '-RRB-', '$', '#']
+
+def sentences_n_fmt(fname, n=-1, fmt='txt'):
+  """Read the sentences of fname according to the corpus format fmt.
+
+  fmt 'txt' reads one whitespace-tokenized sentence per line via
+  sentences_n; fmt 'wsj' reads bracketed trees via bracket_sentences_n
+  with WSJ_PUNC_POS as the tags to drop.  n is passed through as the
+  maximum sentence length (negative keeps every sentence).
+
+  Raises ValueError for any other fmt, instead of returning None and
+  failing later when the caller concatenates the results.
+  """
   
+  if fmt == 'txt':
+    return sentences_n(fname, n)
+
+  elif fmt == 'wsj':
+    return bracket_sentences_n(fname, n, WSJ_PUNC_POS)
+
+  raise ValueError("unknown corpus format %r (expected 'txt' or 'wsj')" % (fmt,))
+
 
 if __name__ == '__main__':
  
   op.add_option('-A', '--output_all', action='store_true')
   op.add_option('-n', '--len_trn', type='int', default=10)
   op.add_option('-N', '--len_tst', type='int', default=10)
+  op.add_option('-f', '--fmt_trn', default='txt')
+  op.add_option('-F', '--fmt_tst', default='txt')
   op.add_option('-I', '--em_iter', type='int', default=10)
   op.add_option('-2', '--ccm2', action='store_true')
 
   opt, args = op.parse_args()
 
-  train = sentences_n(opt.train, opt.len_trn)
+  train_f = sum([glob(t) for t in opt.train.split()], [])
+  train = sum([sentences_n_fmt(f, opt.len_trn, opt.fmt_trn) for f in train_f], [])
+
   if opt.test:
-    test = sentences_n(opt.test, opt.len_tst)
+    test_f = sum([glob(t) for t in opt.test.split()], [])
+    test = sum([sentences_n_fmt(f, opt.len_tst, opt.fmt_tst) for f in test_f], [])
+
   else:
     test = train
 
 
   ccm = (opt.ccm2 and CCM2 or CCM)(train)
   for it in range(opt.em_iter):
+    print >>log_fh, '%2d' % it,
     ccm.step()
     if opt.output and opt.output_all:
       out_fh = open('%s.%02d' % (opt.output, it), 'w')
       ccm.parse_batch(test, out_fh)
-    out_fh.close()
-      
+      out_fh.close()
+  
+  if test is not train:
+    print >>log_fh, 'test corpus:', len(test), 
+    print >>log_fh, 'sentences,', sum(map(len, test)), 'tokens'
+
   out_fh = opt.output and open(opt.output, 'w') or stdout
   ccm.parse_batch(test, out_fh)
-  if opt.output is not None: out_fh.close()
+  if opt.output is not None: 
+    out_fh.close()