# Source
#
# dag-discovery / src / discovery_tonfa.py
#
# Full commit
# Efficient non-chatty discovery of nodes missing in one DAG vs another.
#
# Copyright 2009 Peter Arrenbrecht <peter@arrenbrecht.ch>
#
# This software may be used and distributed according to the terms of the
# GNU General Public License version 2, incorporated herein by reference.

import random
from collections import deque

from testing import DiscoveryTests, assertnodes

# Size limit used by both sides: Client._sample() returns all candidates when
# there are at most this many of them and otherwise aims for a sample of this
# size; Server.discover() only hands back its remaining nodes when there are
# at most this many.
MAX_SAMPLE = 200
# Presumably a tracing switch; nothing in the code visible here reads it.
TRACE = False


class Config(object):
    """Settings object passed to participants; it carries no attributes."""

    def __init__(self):
        """Create a configuration without any settings."""


class Participant(object):
    """Common base of both protocol sides: stores the DAG, writer and config."""

    def __init__(self, dag, writer, cfg):
        """Remember `dag`, `writer` and `cfg` as same-named attributes."""
        self.dag, self.writer, self.cfg = dag, writer, cfg


def log2(n):
    """Count how often `n` can be floor-halved while it is still positive.

    For a positive integer this is its bit length, i.e. floor(log2(n)) + 1
    rather than the plain base-2 logarithm. Any `n` that is not greater
    than zero yields 0.
    """
    halvings = 0
    rest = n
    while rest > 0:
        rest = rest // 2
        halvings += 1
    return halvings


class Client(Participant):
    """Client side of the discovery protocol.

    Sends samples of its own nodes to a server and uses the answers to work
    out the heads of the node set both DAGs have in common.
    """

    def __init__(self, dag, writer, cfg):
        Participant.__init__(self, dag, writer, cfg)

    def commonheads(self, server):
        """Return the heads of the nodes known to both this client and `server`.

        `server` must offer `heads()` and `discover(sample)`. The latter has
        to return a pair `(yesno, allremaining)`: `yesno[k]` tells whether the
        server knows the k-th node of `sample`, and `allremaining` is either
        None or a collection of server-side nodes.

        When every server head is already known locally, the server's heads
        are returned directly. Otherwise samples are exchanged until no own
        node is left undecided, or until the server hands over
        `allremaining`; the result is then `dag.headsofconnectedset(common)`.
        Progress is reported through `self.writer`.
        """
        dag = self.dag
        nodes = dag.nodeset()

        # Work on a copy: the set is narrowed in place below, and the object
        # handed out by nodeset() must not be modified behind the DAG's back.
        undecided = set(nodes)  # own nodes where I don't know if the server knows them
        common = set()  # nodes we both know
        missing = set()  # own nodes the server lacks

        self.writer.step("sampling")
        # A list pins the order, so answer k always refers to sample[k].
        sample = list(self._sample(undecided, common))

        # first roundtrip queries server's heads too
        self.writer.step("querying")
        roundtrips = 1
        srvheads = set(server.heads())
        yesno, allremaining = server.discover(sample)
        self.writer.done()

        if not (srvheads - nodes):

            # all server's heads known
            self.writer.show("all server heads known")
            result = srvheads

        else:

            # Server heads we have, and everything dag.ancestors() reports for
            # them, are known to both sides.
            common.update(dag.ancestors(srvheads.intersection(undecided)))
            undecided.difference_update(common)

            while undecided:

                self.writer.show("still undecided: %i, sample size: %s"
                                 % (len(undecided), len(sample)), emptyline=True)

                self.writer.step("updating common")
                commoninsample = set(n for k, n in enumerate(sample) if yesno[k])
                common.update(dag.ancestors(commoninsample, common))

                if allremaining is not None:
                    # The server sent all its remaining nodes, so whatever of
                    # them we still had undecided is common; no more queries.
                    self.writer.step("updating common for allremaining")
                    commonremain = set(allremaining).intersection(undecided)
                    common.update(dag.ancestors(commonremain, common))
                    break

                self.writer.step("updating missing")
                missinginsample = set(n for k, n in enumerate(sample) if not yesno[k])
                missing.update(dag.descendants(missinginsample, missing))

                self.writer.step("updating undecided")
                undecided.difference_update(missing)
                undecided.difference_update(common)

                if not undecided:
                    break

                self.writer.step("sampling")
                sample = list(self._sample(undecided, common))
                self.writer.step("querying")
                roundtrips += 1
                yesno, allremaining = server.discover(sample)
                self.writer.done()

            result = dag.headsofconnectedset(common)

        self.writer.done()
        self.writer.show("number of iterations: %i" % roundtrips)
        return result

    def _sample(self, nodes, stop):
        """Pick the nodes out of `nodes` to ask the server about.

        Returns a new set. If `nodes` holds at most MAX_SAMPLE entries, all of
        them are returned. Otherwise the result always contains
        `dag.headsofconnectedset(nodes)`; the rest of the MAX_SAMPLE budget is
        taken from two breadth-first walks (heads towards roots, and back),
        randomly trimmed or topped up to fit. If the heads alone reach or
        exceed MAX_SAMPLE, only the heads are returned.

        `stop` is accepted but not used by this implementation.
        """
        if len(nodes) <= MAX_SAMPLE:
            return set(nodes)

        dag = self.dag

        self.writer.indent()
        self.writer.step("headsof")
        heads = dag.headsofconnectedset(nodes)
        self.writer.done()

        def bfs(start, parents):
            """Walk breadth-first from `start`, staying inside `nodes`.

            `parents(node)` yields the neighbours to follow. Returns
            `(sample, roots)`: `sample` holds the visited nodes whose distance
            from the start (start nodes count as 1) is 1, 2, 4, 8, ...;
            `roots` holds the visited nodes without a neighbour in `nodes`.
            """
            sample = set()
            dist = {}
            order = []
            visit = deque(start)  # FIFO queue; popleft() is O(1)
            seen = set()

            roots = set()

            while visit:
                curr = visit.popleft()
                if curr in seen:
                    continue
                d = dist.setdefault(curr, 1)
                order.append(curr)
                seen.add(curr)

                isroot = True
                for p in parents(curr):
                    if p not in nodes:
                        continue
                    dist.setdefault(p, d + 1)
                    visit.append(p)
                    isroot = False
                if isroot:
                    roots.add(curr)

            self.writer.step("sample")
            # `order` is sorted by distance, so doubling `factor` once per
            # node is enough to keep up with the growing distances.
            factor = 1
            for n in order:
                if dist[n] > factor:
                    factor *= 2
                if dist[n] == factor:
                    sample.add(n)
            return sample, roots

        self.writer.step("heads -> roots")
        downsample, roots = bfs(heads, dag.parents)

        self.writer.step("roots -> heads")
        upsample, roots = bfs(roots, dag.children)

        self.writer.step("finalize sample")
        sample = upsample | downsample
        assert sample
        sample.difference_update(heads)

        # The heads are always sent; only what is left of the budget can be
        # filled with other nodes. Clamp at zero so an excess of heads does
        # not turn into a negative sample size.
        budget = max(0, MAX_SAMPLE - len(heads))
        if len(sample) > budget:
            # random.sample() needs a sequence on newer Pythons, hence list().
            sample = set(random.sample(list(sample), budget))
        elif len(sample) < budget:
            more = budget - len(sample)
            self.writer.step("filling with %d random samples" % more)
            sample.update(random.sample(list(nodes - sample - heads), more))
        sample.update(heads)

        self.writer.unindent()
        return sample


class Server(Participant):
    """Server side of the discovery protocol: answers a client's queries."""

    def __init__(self, dag, writer, cfg):
        Participant.__init__(self, dag, writer, cfg)

    def heads(self):
        """Return the heads of the server's DAG, as given by `dag.heads()`."""
        return self.dag.heads()

    def discover(self, sample):
        """Tell which nodes of `sample` the server knows.

        `sample` may be any iterable of nodes. Returns `(yesno, allremaining)`:

        - `yesno` is a list of booleans in the iteration order of `sample`,
          True where the node is in the server's node set.
        - `allremaining` is the server's node set minus
          `dag.ancestors(known sample nodes)`, or None when that set has more
          than MAX_SAMPLE entries.
        """
        dag = self.dag
        nodes = dag.nodeset()

        # One pass over the sample builds both the positional answers and the
        # set of sampled nodes the server has.
        yesno = []
        known = set()
        for n in sample:
            present = n in nodes
            yesno.append(present)
            if present:
                known.add(n)

        allremaining = nodes - dag.ancestors(known)
        self.writer.show("server remaining: %i" % len(allremaining))
        if len(allremaining) > MAX_SAMPLE:
            # Too many to send along; the client has to keep sampling.
            allremaining = None

        return yesno, allremaining


class Tests(DiscoveryTests):
    """Runs the Client/Server discovery above against the shared test DAGs."""

    def __init__(self):
        DiscoveryTests.__init__(self, quiet=False)
        self.cfg = Config()

    def setupdags(self, a, b, ans, bns):
        """Record the expected result: heads of the nodes in both `ans` and `bns`."""
        shared = ans & bns
        self.expected = a.headsofconnectedset(shared)

    def test(self, cdag, sdag, cns, sns):
        """Run one discovery between `cdag` and `sdag` and check its result."""
        server = Server(sdag, self.writer, self.cfg)
        client = Client(cdag, self.writer, self.cfg)

        # Passed as `quiesce` to the writer's "traffic" section.
        quiet_traffic = False
        self.writer.section("traffic", quiesce=quiet_traffic)

        actual = client.commonheads(server)
        self.writer.unindent()

        assertnodes(list(self.expected), list(actual))

if __name__ == "__main__":
    # Fixed seed, so the random part of the sampling repeats between runs.
    random.seed(42)
    suite = Tests()
    suite.testall()