Commits

Benoit Boissinot committed a6d4537 Merge

merge with parren

Comments (0)

Files changed (2)

     '''generic interface for DAGs'''
 
     def __init__(self):
+        self._children = None
         pass
     def heads(self):
         '''return array of head ids'''
     def parents(self, id):
         '''return array of parents ids of id'''
         return []
+    def children(self, id):
+        '''return array of child ids of id'''
+        if self._children is None:
+            cs = {}
+            for n in self.walk():
+                ps = self.parents(n)
+                for p in ps:
+                    cs.setdefault(p, []).append(n)
+            self._children = cs
+        return self._children.get(id, [])
 
     def subset(self, heads):
         '''return sub-dag with only the given heads and their ancestors'''
         '''return set of nodes from heads to stops (or root)'''
         return set(self.walk(heads, stops))
 
-    def headsof(self, nodes):
+    def headsof(self, nodes, stops=None):
         '''return subset of nodes where no node has a descendant in nodes'''
         hds = set(nodes)
         if not hds:

src/discovery_tonfa.py

 from testing import DiscoveryTests, assertnodes
 
 MAX_SAMPLE = 200
+TRACE = False
+
 
 class Config(object):
 
 def clever_sample(dag, nodes, stop):
     if len(nodes) < MAX_SAMPLE:
         return set(nodes)
-    heads = dag.headsof(nodes)
+
+    if TRACE: print "headsof"
+    heads = dag.headsof(nodes, stop)
+    if TRACE: print heads
     sample = set()
 
     dist = {}
     order = []
     visit = list(heads)
     seen = set(stop)
+    cands = []
 
+    if TRACE: print "heads -> roots"
     roots = set()
 
     while visit:
         d = dist.setdefault(curr, 1)
         order.append(curr)
         seen.add(curr)
+        cands.append(curr)
 
         if not len(list(p for p in dag.parents(curr) if p not in stop)):
             roots.add(curr)
             dist.setdefault(p, d+1)
             visit.append(p)
 
+    if TRACE: print "sample"
     factor = 1
     for n in order:
         if dist[n] > factor:
         if dist[n] == factor:
             sample.add(n)
 
-    children = {}
-    for n in dag.walk():
-        ps = dag.parents(n)
-        for p in ps:
-            children.setdefault(p, []).append(n)
-
+    if TRACE: print "roots -> heads"
     visit = list(roots)
     order = []
     dist = {}
         d = dist.setdefault(curr, 1)
         order.append(curr)
         seen.add(curr)
-        for c in children.get(curr, []):
+        for c in dag.children(curr):
             if c not in nodes:
                 continue
             dist.setdefault(c, d+1)
             visit.append(c)
 
+    if TRACE: print "sample"
     factor = 1
     for n in order:
         if dist[n] > factor:
     sample.difference_update(heads)
     if len(sample)+len(heads) > MAX_SAMPLE:
         sample = set(random.sample(sample, MAX_SAMPLE-len(heads)))
+    elif len(sample)+len(heads) < 200:
+        if TRACE: print "Filling from", len(sample) + len(heads)
+        sample.update(random.sample(list(set(cands) - sample - heads), 200 - len(sample) - len(heads)))
     sample.update(heads)
     return sample
 
     def common(self, server):
         i = 0
         while self._unknown:
+            if TRACE: self.writer.show("sampling...")
             sample = clever_sample(self.dag, self._unknown, self._common)
+            if TRACE: self.writer.show("querying...")
             common, remain = server.discover(sample)
             self.writer.show("number of unknown left: %i, sample size: %s"
                              % (len(self._unknown), len(sample)))
             i += 1
 
+            if TRACE: self.writer.show("updating missing...")
             self._missing.update(self.dag.descendants((n for n in sample if n not in common), self._missing))
+            if TRACE: self.writer.show("updating common...")
             self._common.update(self.dag.ancestors(list(n for n in common if n in self._unknown), self._common))
 
             if remain:
+                if TRACE: self.writer.show("updating missing...")
                 self._missing.update(self.dag.descendants((n for n in self._unknown if n not in remain), self._missing))
+                if TRACE: self.writer.show("updating common...")
                 self._common.update(self.dag.ancestors(list(n for n in remain if n in self._unknown)), self._common)
                 break
 
+            if TRACE: self.writer.show("updating unknown...")
             self._unknown.difference_update(self._missing)
             self._unknown.difference_update(self._common)