Peter Arrenbrecht avatar Peter Arrenbrecht committed aa7ec95

make _sample more symmetric and use it for server's sample too

* add and use dag.inverse()
* make server's sample use _sample from roots only
* simplify commonheads a bit

Comments (0)

Files changed (2)

     '''generic interface for DAGs'''
 
     def __init__(self):
-        self._children = None
+        self._inverse = None
+        self.headsarehigher = True
         pass
+
     def heads(self):
         '''return array of head ids'''
         return []
+
     def parents(self, id):
         '''return array of parents ids of id'''
         return []
-    def children(self, id):
-        '''return array of child ids of id'''
-        if self._children is None:
-            cs = {}
-            for n in self.walk():
-                ps = self.parents(n)
-                for p in ps:
-                    cs.setdefault(p, []).append(n)
-            self._children = cs
-        return self._children.get(id, [])
+
+    def inverse(self):
+        '''return inverse DAG, where parents becomes children, etc.'''
+        if self._inverse is None:
+            self._inverse = InverseDAG(self)
+        return self._inverse
 
     def subset(self, heads):
         '''return sub-dag with only the given heads and their ancestors'''
         return SubDAG(self, heads)
 
-    def nodeset(self, heads=None, stops=None):
-        '''return set of nodes from heads to stops (or root)'''
-        return set(self.walk(heads, stops))
+    def ancestors(self, start, stop=None):
+        '''return all ancestors of nodes in start, but stop DAG walk at stop'''
+        if stop is not None:
+            seen = set(stop)
+        else:
+            seen = set()
+        anc = []
+        pending = list(start)
+        while pending:
+            n = pending.pop()
+            if n not in seen:
+                anc.append(n)
+                seen.add(n)
+                pending.extend(self.parents(n))
+        return anc
+
+    def descendants(self, start, stop=None):
+        return self.inverse().ancestors(start, stop)
 
     def headsof(self, nodes):
         '''return subset of nodes where no node has a descendant in nodes'''
             if n in hds:
                 ps = self.parents(n)
                 if ps:
-                    ancestors = self.nodeset(heads=ps, stops=seen)
-                    seen.update(ancestors)
-                    hds.difference_update(ancestors)
+                    anc = self.ancestors(ps, stops=seen)
+                    seen.update(anc)
+                    hds.difference_update(anc)
         assert hds
         return hds
 
         assert hds
         return hds
 
-    def _relatives(self, relativesfn, start, stop=None):
-        if stop is not None:
-            seen = set(stop)
-        else:
-            seen = set()
-        rels = set()
-        pending = list(start)
-        while pending:
-            n = pending.pop()
-            if n not in seen:
-                rels.add(n)
-                seen.add(n)
-                pending.extend(relativesfn(n))
-        return rels
-
-    def descendants(self, start, stop=None):
-        return self._relatives(self.children, start, stop)
-
-    def ancestors(self, start, stop=None):
-        return self._relatives(self.parents, start, stop)
+    def nodeset(self, heads=None, stops=None):
+        '''return set of nodes from heads to stops (or root)'''
+        if heads is None:
+            heads = self.heads()
+        return set(self.ancestors(heads, stops))
 
     def walk(self, heads=None, stops=None):
         '''iterate ids from heads to stops (or root), depth-first'''
         '''return desc as parseable by MemDAG.fromdesc()'''
         return "\n".join(self.descgen(self.walk(heads, stops)))
 
+class InverseDAG(DAG):
+
+    def __init__(self, orig):
+        DAG.__init__(self)
+        self.headsarehigher = False
+        roots = []
+        cs = {}
+        for n in orig.walk():
+            ps = orig.parents(n)
+            if ps:
+                for p in ps:
+                    cs.setdefault(p, []).append(n)
+            else:
+                roots.append(n)
+        self._roots = roots
+        self._children = cs
+
+    def heads(self):
+        return self._roots
+
+    def parents(self, id):
+        return self._children.get(id, [])
+
 class SubDAG(DAG):
     '''subdag of an existing DAG'''
     def __init__(self, dag, heads):

src/discovery_tonfa.py

 TRACE = False
 
 
+def log2(n):
+    i = 0
+    while n > 0:
+        n //= 2
+        i += 1
+    return i
+
+
 class Config(object):
 
     def __init__(self):
         self.writer = writer
         self.cfg = cfg
 
+    def _sample(self, nodes, stop, fromheads=True, fromroots=True):
 
-def log2(n):
-    i = 0
-    while n > 0:
-        n //= 2
-        i += 1
-    return i
+        if len(nodes) <= MAX_SAMPLE:
+            return set(nodes)
+
+        sample = set()
+        always = set()
+
+        self.writer.indent(quiesce=False)
+
+        dags = []
+        if fromheads:
+            dags.append(self.dag)
+        if fromroots:
+            dags.append(self.dag.inverse())
+
+        for dag in dags:
+
+            self.writer.step("headsof")
+            heads = dag.headsofconnectedset(nodes)
+
+            self.writer.step("walk")
+            dist = {}
+            order = []
+            visit = list(heads)
+            seen = set(stop)
+            while visit:
+                curr = visit.pop(0)
+                if curr in seen:
+                    continue
+                d = dist.setdefault(curr, 1)
+                order.append(curr)
+                seen.add(curr)
+                for p in dag.parents(curr):
+                    dist.setdefault(p, d + 1)
+                    visit.append(p)
+
+            self.writer.step("sample")
+            factor = 1
+            for n in order:
+                d = dist[n]
+                if d > factor:
+                    factor *= 2
+                if d == factor:
+                    sample.add(n)
+
+            if not always:
+                always.update(heads)
+
+        self.writer.step("finalize")
+        assert sample
+        sample.difference_update(always)
+        desiredlen = MAX_SAMPLE - len(always)
+        if len(sample) > desiredlen:
+            sample = set(random.sample(sample, desiredlen))
+        elif len(sample) < desiredlen:
+            more = desiredlen - len(sample)
+            self.writer.step("filling with %d random samples" % more)
+            sample.update(random.sample(list(nodes - sample - heads), more))
+        sample.update(always)
+
+        self.writer.unindent()
+        return sample
 
 
 class Client(Participant):
 
     def __init__(self, dag, writer, cfg):
         Participant.__init__(self, dag, writer, cfg)
+        self.roundtrips = 0
 
     def commonheads(self, server):
 
         self.writer.step("querying")
         i = 1
         srvheads = set(server.heads())
-        yesno, allremaining = server.discover(sample)
+        yesno, srvsample, noneremain = server.discover(sample)
         self.writer.done()
 
         if not (srvheads - nodes):
                 self.writer.step("updating common")
                 commoninsample = set(n for i, n in enumerate(sample) if yesno[i])
                 common.update(dag.ancestors(commoninsample, common))
+                if srvsample:
+                    commoninsample = set(srvsample).intersection(undecided)
+                    common.update(dag.ancestors(commoninsample, common))
 
-                if allremaining is not None:
-                    self.writer.step("updating common for allremaining")
-                    commonremain = set(allremaining).intersection(undecided)
-                    common.update(dag.ancestors(commonremain, common))
+                if noneremain:
+                    # server sent all remaining possibly undecided nodes in sample
                     break
 
                 self.writer.step("updating missing")
                 sample = self._sample(undecided, common)
                 self.writer.step("querying")
                 i += 1
-                yesno, allremaining = server.discover(sample)
+                yesno, srvsample, noneremain = server.discover(sample)
                 self.writer.done()
 
             result = dag.headsofconnectedset(common)
 
         self.writer.done()
-        self.writer.show("number of iterations: %i" % i)
+        self.roundtrips = i
         return result
 
-    def _sample(self, nodes, stop):
-        if len(nodes) <= MAX_SAMPLE:
-            return set(nodes)
-
-        dag = self.dag
-
-        self.writer.indent()
-        self.writer.step("headsof")
-        heads = dag.headsofconnectedset(nodes)
-        self.writer.done()
-
-        sample = set()
-
-        dist = {}
-        order = []
-        visit = list(heads)
-        seen = set(stop)
-        cands = []
-
-        self.writer.step("heads -> roots")
-        roots = set()
-
-        while visit:
-            curr = visit.pop(0)
-            if curr in seen:
-                continue
-            d = dist.setdefault(curr, 1)
-            order.append(curr)
-            seen.add(curr)
-            cands.append(curr)
-
-            if not len(list(p for p in dag.parents(curr) if p not in stop)):
-                roots.add(curr)
-
-            for p in dag.parents(curr):
-                dist.setdefault(p, d+1)
-                visit.append(p)
-
-        self.writer.step("sample")
-        factor = 1
-        for n in order:
-            if dist[n] > factor:
-                factor *= 2
-            if dist[n] == factor:
-                sample.add(n)
-
-        self.writer.step("roots -> heads")
-        visit = list(roots)
-        order = []
-        dist = {}
-        seen = set()
-        while visit:
-            curr = visit.pop(0)
-            if curr in seen:
-                continue
-            d = dist.setdefault(curr, 1)
-            order.append(curr)
-            seen.add(curr)
-            for c in dag.children(curr):
-                if c not in nodes:
-                    continue
-                dist.setdefault(c, d+1)
-                visit.append(c)
-
-        self.writer.step("sample")
-        factor = 1
-        for n in order:
-            if dist[n] > factor:
-                factor *= 2
-            if dist[n] == factor:
-                sample.add(n)
-
-        self.writer.step("finalize sample")
-        assert sample
-        sample.difference_update(heads)
-        if len(sample)+len(heads) > MAX_SAMPLE:
-            sample = set(random.sample(sample, MAX_SAMPLE - len(heads)))
-        elif len(sample)+len(heads) < 200:
-            more = MAX_SAMPLE - len(sample) - len(heads)
-            self.writer.step("filling with %d random samples" % more)
-            sample.update(random.sample(list(set(cands) - sample - heads), more))
-        sample.update(heads)
-
-        self.writer.unindent()
-        return sample
-
 
 class Server(Participant):
 
 
     def discover(self, sample):
         dag = self.dag
+
+        self.writer.indent()
+
+        self.writer.step("yesno")
         nodes = dag.nodeset()
-
         yesno = [False for i in xrange(len(sample))]
         known = set()
         for i, n in enumerate(sample):
                 known.add(n)
                 yesno[i] = True
 
-        allremaining = nodes - dag.ancestors(known)
+        self.writer.step("remaining")
+        common = set(dag.ancestors(known))
+        allremaining = nodes - common
+        self.writer.done()
+
         self.writer.show("server remaining: %i" % len(allremaining))
+        self.writer.unindent()
+
         if len(allremaining) > MAX_SAMPLE:
-            allremaining = None
+            mysample = self._sample(allremaining, common, fromheads=False)
+            return yesno, mysample, False
 
-        return yesno, allremaining
+        return yesno, allremaining, True
 
 
 class Tests(DiscoveryTests):
         actual = c.commonheads(s)
         self.writer.unindent()
         assertnodes(list(self.expected), list(actual))
+        self.writer.show("number of iterations: %i" % c.roundtrips)
 
 if __name__ == "__main__":
     random.seed(0)
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.