Source

dag-discovery / src / discovery_tonfa.py

Full commit
# Efficient non-chatty discovery of nodes missing in one DAG vs another.
#
# Copyright 2009 Peter Arrenbrecht <peter@arrenbrecht.ch>
#
# This software may be used and distributed according to the terms of the
# GNU General Public License version 2, incorporated herein by reference.

import random
from collections import deque

from testing import DiscoveryTests, assertnodes

# Sample-size budget: clever_sample() returns inputs with fewer nodes than
# this whole and otherwise caps its sample at this size; Server.discover()
# only sends back its remaining node set when it is no larger than this.
MAX_SAMPLE = 200
# When True, clever_sample() prints tracing output.
TRACE = False


class Config(object):
    """Configuration shared by the participants.

    Currently carries no settings; it exists so Client and Server keep a
    uniform constructor signature.
    """


class Participant(object):
    """Common base for both sides of the discovery exchange.

    Stores the participant's DAG, the writer it reports progress through,
    and the shared configuration object.
    """

    def __init__(self, dag, writer, cfg):
        self.dag, self.writer, self.cfg = dag, writer, cfg


def log2(n):
    """Count how many floor-halvings it takes to bring `n` down to zero.

    Despite the name this is the bit length rather than floor(log2(n)):
    log2(1) == 1, log2(8) == 4, and any n <= 0 gives 0.
    """
    bits = 0
    remaining = n
    while remaining > 0:
        bits += 1
        remaining = remaining // 2
    return bits

def _sample_at_powers_of_two(order, dist):
    """Return the nodes of `order` whose distance is 1, 2, 4, 8, ...

    `order` is a breadth-first visiting order and `dist` maps each visited
    node to its distance from the walk's start.  The factor advances by one
    doubling per node, so only nodes whose distance hits it exactly are kept.
    """
    picked = set()
    factor = 1
    for n in order:
        if dist[n] > factor:
            factor *= 2
        if dist[n] == factor:
            picked.add(n)
    return picked

def clever_sample(dag, nodes, stop):
    """Pick a sample of `nodes` to ask the server about.

    Inputs with fewer than MAX_SAMPLE nodes are returned whole (as a new
    set).  Otherwise the sample is the heads of `nodes` plus nodes at
    power-of-two distances along two breadth-first walks: one from the heads
    down through parents (never entering `stop`), and one from the roots
    found that way back up through children that lie in `nodes`.  The result
    is then trimmed at random to MAX_SAMPLE nodes, or padded at random with
    other nodes from the downward walk towards MAX_SAMPLE as far as enough
    candidates exist.  The heads are always kept, even if they alone exceed
    MAX_SAMPLE.

    :param dag: graph providing headsofconnectedset(), parents(), children()
    :param nodes: set of (undecided) nodes to sample from
    :param stop: set of already-decided nodes the downward walk must not enter
    :returns: a set of nodes
    """
    if len(nodes) < MAX_SAMPLE:
        return set(nodes)

    if TRACE:
        print("headsof")
    heads = dag.headsofconnectedset(nodes)
    if TRACE:
        print(heads)

    # Walk 1: heads -> roots, breadth first, never entering `stop`.
    if TRACE:
        print("heads -> roots")
    down_dist = {}
    down_order = []  # also the candidate pool for random padding below
    seen = set(stop)
    roots = set()
    visit = deque(heads)
    while visit:
        curr = visit.popleft()
        if curr in seen:
            continue
        d = down_dist.setdefault(curr, 1)
        down_order.append(curr)
        seen.add(curr)
        parents = list(dag.parents(curr))
        # No parent outside `stop` means curr is a root of the sampled area.
        if not any(p not in stop for p in parents):
            roots.add(curr)
        for p in parents:
            down_dist.setdefault(p, d + 1)
            visit.append(p)

    if TRACE:
        print("sample")
    sample = _sample_at_powers_of_two(down_order, down_dist)

    # Walk 2: roots -> heads, breadth first, staying inside `nodes`.
    if TRACE:
        print("roots -> heads")
    up_dist = {}
    up_order = []
    seen = set()
    visit = deque(roots)
    while visit:
        curr = visit.popleft()
        if curr in seen:
            continue
        d = up_dist.setdefault(curr, 1)
        up_order.append(curr)
        seen.add(curr)
        for c in dag.children(curr):
            if c in nodes:
                up_dist.setdefault(c, d + 1)
                visit.append(c)

    if TRACE:
        print("sample")
    sample.update(_sample_at_powers_of_two(up_order, up_dist))

    assert sample
    sample.difference_update(heads)
    total = len(sample) + len(heads)
    if total > MAX_SAMPLE:
        # random.sample() needs a sequence (sets are rejected on 3.11+).
        # Clamp at 0 so an oversized head set keeps just the heads instead
        # of raising on a negative sample size.
        keep = max(0, MAX_SAMPLE - len(heads))
        sample = set(random.sample(list(sample), keep))
    elif total < MAX_SAMPLE:
        if TRACE:
            print("Filling from %d" % total)
        pool = list(set(down_order) - sample - heads)
        # Never ask for more padding than there are candidates left.
        want = min(MAX_SAMPLE - total, len(pool))
        sample.update(random.sample(pool, want))
    sample.update(heads)
    return sample

class Client(Participant):
    """Discovery side that drives the exchange against a Server."""

    def __init__(self, dag, writer, cfg):
        Participant.__init__(self, dag, writer, cfg)

    def commonheads(self, server):
        """Return the heads of the nodes shared by this client and `server`.

        If every server head is already a local node, the server heads are
        the answer.  Otherwise samples of the still-undecided local nodes
        are sent to `server.discover()`.  Nodes the server knows are added
        to `common` (via dag.ancestors) and nodes it lacks are added to
        `missing` (via dag.descendants).  This repeats until nothing is
        undecided, or until the server returns its whole remaining set
        (allremaining not None).

        :param server: object providing heads() and discover(sample)
        :returns: set of common head nodes
        """
        dag = self.dag
        nodes = dag.nodeset()

        # Work on a copy: `undecided` is shrunk in place below, and
        # dag.nodeset() may hand back the DAG's own set.
        undecided = set(nodes)  # own nodes where I don't know if the server knows them
        common = set()  # nodes we both know
        missing = set()  # own nodes the server lacks

        self.writer.step("sampling")
        sample = clever_sample(dag, undecided, common)

        # first roundtrip queries server's heads too
        self.writer.step("querying")
        roundtrips = 1
        srvheads = set(server.heads())
        yesno, allremaining = server.discover(sample)
        self.writer.done()

        if not (srvheads - nodes):

            # all server's heads known
            self.writer.show("all server heads known")
            result = srvheads

        else:

            common.update(dag.ancestors(srvheads.intersection(undecided)))
            undecided.difference_update(common)

            while undecided:

                self.writer.show("still undecided: %i, sample size: %s"
                                 % (len(undecided), len(sample)))

                # `sample` is unchanged since discover() iterated it, so its
                # iteration order lines up with the answers in `yesno`.
                self.writer.step("updating common")
                commoninsample = set(n for n, yes in zip(sample, yesno) if yes)
                common.update(dag.ancestors(commoninsample, common))

                if allremaining is not None:
                    # The server sent its whole remaining set, so the shared
                    # nodes can be settled without another round trip.
                    self.writer.step("updating common for allremaining")
                    commonremain = set(allremaining).intersection(undecided)
                    common.update(dag.ancestors(commonremain, common))
                    break

                self.writer.step("updating missing")
                missinginsample = set(n for n, yes in zip(sample, yesno) if not yes)
                missing.update(dag.descendants(missinginsample, missing))

                self.writer.step("updating undecided")
                undecided.difference_update(missing)
                undecided.difference_update(common)

                if not undecided:
                    break

                self.writer.step("sampling")
                sample = clever_sample(dag, undecided, common)
                self.writer.step("querying")
                roundtrips += 1
                yesno, allremaining = server.discover(sample)
                self.writer.done()

            result = dag.headsofconnectedset(common)

        self.writer.done()
        self.writer.show("number of iterations: %i" % roundtrips)
        return result


class Server(Participant):
    """Discovery side that answers the client's membership queries."""

    def __init__(self, dag, writer, cfg):
        Participant.__init__(self, dag, writer, cfg)

    def heads(self):
        """Return the heads of the server's DAG."""
        return self.dag.heads()

    def discover(self, sample):
        """Report which nodes of `sample` this server has.

        :param sample: sized iterable of client nodes (the client passes a set)
        :returns: ``(yesno, allremaining)``. ``yesno[i]`` is True when the
            i-th node produced by iterating `sample` is a server node.
            ``allremaining`` is the set of server nodes outside
            ``dag.ancestors(known)``, where `known` is the set of sample
            nodes found; it is None when that set exceeds MAX_SAMPLE nodes.
        """
        dag = self.dag
        nodes = dag.nodeset()

        # Single pass over `sample`: answer each node in iteration order and
        # collect the ones we know.
        yesno = []
        known = set()
        for n in sample:
            present = n in nodes
            yesno.append(present)
            if present:
                known.add(n)

        allremaining = nodes - dag.ancestors(known)
        self.writer.show("server remaining: %i" % len(allremaining))
        if len(allremaining) > MAX_SAMPLE:
            allremaining = None

        return yesno, allremaining


class Tests(DiscoveryTests):
    """Runs the Client/Server pair from this module through DiscoveryTests."""

    def __init__(self):
        DiscoveryTests.__init__(self, quiet=False)
        self.cfg = Config()

    def setupdags(self, a, b, ans, bns):
        """Record the expected answer: heads of the nodes both sides hold."""
        shared = ans & bns
        self.expected = a.headsofconnectedset(shared)

    def test(self, cdag, sdag, cns, sns):
        """Run one discovery and compare the result with the expectation."""
        server = Server(sdag, self.writer, self.cfg)
        client = Client(cdag, self.writer, self.cfg)

        # Traffic output stays visible for this run (quiesce is off).
        self.writer.section("traffic", quiesce=False)
        actual = client.commonheads(server)
        self.writer.unindent()
        assertnodes(list(self.expected), list(actual))

if __name__ == "__main__":
    # Fixed seed so the random trimming/padding in clever_sample() is
    # reproducible from run to run.
    random.seed(0)
    suite = Tests()
    suite.testall()