# Source
#
# dag-discovery / src / dagutil.py
#
# Full commit
# DAG utilities
#
# Copyright 2009 Peter Arrenbrecht <peter@arrenbrecht.ch>
#
# This software may be used and distributed according to the terms of the
# GNU General Public License version 2, incorporated herein by reference.

import dagparser

# Truthy marker constant. NOTE(review): presumably intended to be passed to a
# DAG.walk() generator via send() to stop following the current path (walk()
# breaks when the yield expression evaluates true) -- confirm against callers.
BREAK = True

class DAG(object):
    '''generic interface for DAGs

    Subclasses provide heads() and parents(); the remaining methods are
    derived from those two. Node ids must be hashable, and several methods
    (headsof, fmtidwalk, asparse) additionally sort them.'''

    def __init__(self):
        # lazily built {parent id: [child ids]} map, filled in by children()
        self._children = None

    def heads(self):
        '''return array of head ids'''
        return []

    def parents(self, id):
        '''return array of parents ids of id'''
        return []

    def children(self, id):
        '''return array of child ids of id

        The child map is computed on first use by walking from heads() and
        is then cached on the instance; it is never invalidated here.'''
        if self._children is None:
            childmap = {}
            for n in self.walk():
                for p in self.parents(n):
                    childmap.setdefault(p, []).append(n)
            self._children = childmap
        return self._children.get(id, [])

    def subset(self, heads):
        '''return sub-dag with only the given heads and their ancestors'''
        return SubDAG(self, heads)

    def nodeset(self, heads=None, stops=None):
        '''return set of nodes from heads to stops (or root)'''
        return set(self.walk(heads, stops))

    def headsof(self, nodes):
        '''return subset of nodes where no node has a descendant in nodes

        Returns a new set. Candidates are visited in descending id order;
        for each one still considered a head, all of its ancestors are
        removed from the result.'''
        hds = set(nodes)
        seen = set()
        if not hds:
            return hds
        for n in sorted(nodes, reverse=True):
            if n in hds:
                ps = self.parents(n)
                # guard: walk() treats an empty heads list as "all heads"
                if ps:
                    # ancestors already collected act as stops, so no part
                    # of the graph is walked twice
                    ancestors = self.nodeset(heads=ps, stops=seen)
                    seen.update(ancestors)
                    hds.difference_update(ancestors)
        assert hds
        return hds

    def headsofconnectedset(self, nodes):
        '''return subset of connected set so that no node has a descendant in it

        By "connected set" we mean that if an ancestor and a descendant are in
        the set, then so is at least one path connecting them.'''
        hds = set(nodes)
        if not hds:
            return hds
        # in a connected set every non-head is the direct parent of a member
        for n in nodes:
            for p in self.parents(n):
                hds.discard(p)
        assert hds
        return hds

    def _relatives(self, relativesfn, start, stop=None):
        '''return set of nodes reachable from start via relativesfn

        relativesfn maps a node id to an iterable of related ids (parents or
        children). Nodes in stop are neither returned nor expanded; nodes in
        start are part of the result unless they are in stop.'''
        if stop is not None:
            seen = set(stop)
        else:
            seen = set()
        rels = set()
        pending = list(start)
        while pending:
            n = pending.pop()
            if n not in seen:
                rels.add(n)
                seen.add(n)
                pending.extend(relativesfn(n))
        return rels

    def descendants(self, start, stop=None):
        '''return set of start nodes and their descendants, excluding stop'''
        return self._relatives(self.children, start, stop)

    def ancestors(self, start, stop=None):
        '''return set of start nodes and their ancestors, excluding stop'''
        return self._relatives(self.parents, start, stop)

    def walk(self, heads=None, stops=None):
        '''iterate ids from heads to stops (or root), depth-first

        heads defaults to self.heads(); an empty heads is treated the same
        as None. Ids in stops are neither yielded nor followed, and every id
        is yielded at most once.

        If the consumer send()s a true value into the generator, the walk
        stops following the parents of the id just yielded along the current
        path and continues with the next pending branch.'''
        pending = list(heads or self.heads())
        seen = set(stops or [])
        while pending:
            node = pending.pop()
            while node not in seen:
                seen.add(node)
                if (yield node):
                    break
                ps = self.parents(node)
                if not ps:
                    break
                # continue along the first parent, queue the others
                node = ps[0]
                pending += ps[1:]

    def reachable(self, heads, node):
        '''return True if node is one of heads or an ancestor of one

        heads may be any iterable of ids; it is copied, not modified.'''
        # list() rather than heads[:] so tuples and sets are accepted too
        pending = list(heads)
        # pre-seeding seen with the target makes the inner loop stop on it
        seen = set([node])
        while pending:
            cur = pending.pop()
            while cur not in seen:
                seen.add(cur)
                ps = self.parents(cur)
                if not ps:
                    break
                cur = ps[0]
                pending += ps[1:]
            if cur == node:
                return True
        return False

    def fmtwalk(self, charfn=None, textfn=None, heads=None, stops=None):
        '''return fmtwalk for dagprinter.printfmtwalk()

        textfn is passed on as the linesfn argument of fmtidwalk().'''
        return self.fmtidwalk(self.walk(heads, stops), charfn, textfn)

    def fmtidwalk(self, ids, charfn=None, linesfn=None):
        '''return fmtwalk for dagprinter.printfmtwalk() from id generator

        Yields (id, char, lines, parents) tuples in descending id order.
        charfn defaults to always 'o'; linesfn defaults to a single line
        holding the id formatted with "%d".'''
        if not charfn:
            def charfn(id): return 'o'
        if not linesfn:
            def linesfn(id): return ["%d" % id]
        for nodeid in reversed(sorted(ids)):
            yield nodeid, charfn(nodeid), linesfn(nodeid), self.parents(nodeid)

    def asparse(self, ids):
        '''generate parse events for ids, in ascending id order

        Yields ('n', (id, parents)) for every id, followed by
        ('l', (id, name)) when the id carries a label. Labels are read from
        an optional self.labels mapping of name -> list of ids; DAGs without
        that attribute are treated as having no labels.'''
        # the base class never sets labels, so do not assume it exists
        labels = getattr(self, 'labels', None) or {}
        lookup = {}
        for name in labels:
            for nodeid in labels[name]:
                lookup[nodeid] = name
        for nodeid in sorted(ids):
            yield 'n', (nodeid, self.parents(nodeid))
            name = lookup.get(nodeid)
            if name:
                yield 'l', (nodeid, name)

    def descgen(self, ids):
        '''generate desc as sequence of strings from id walk'''
        return dagparser.dagtextlines(self.asparse(ids))

    def desc(self, heads=None, stops=None):
        '''return desc as parseable by MemDAG.fromdesc()'''
        return "\n".join(self.descgen(self.walk(heads, stops)))

class SubDAG(DAG):
    '''view of an existing DAG that starts at a chosen list of heads

    Parent lookups are forwarded unchanged to the wrapped DAG; only the
    heads reported by this object differ.'''

    def __init__(self, dag, heads):
        DAG.__init__(self)
        self._heads = heads
        self._dag = dag

    def parents(self, id):
        '''return the parent ids of id as reported by the wrapped DAG'''
        return self._dag.parents(id)

    def heads(self):
        '''return the heads this sub-DAG was created with'''
        return self._heads

class MemDAG(DAG):
    '''in-memory DAG for testing; ids must be 0..n

    Parents are stored in a list indexed by node id, and labels as a
    mapping of label name -> list of node ids.'''

    def __init__(self, heads, parentsarr, labels=None):
        '''create a DAG from head ids, a per-id parents list and labels

        labels defaults to an empty mapping.'''
        DAG.__init__(self)
        self._heads = heads
        self._parentsarr = parentsarr
        self.labels = labels or {}

    def heads(self):
        '''return the head ids given at construction'''
        return self._heads

    def parents(self, id):
        '''return the parents list stored at index id'''
        return self._parentsarr[id]

    @staticmethod
    def fromparse(parse):
        '''build a MemDAG from an iterable of (type, data) parse events

        'n' events carry (id, parents) and 'l' events carry (id, name).
        Parents lists are appended in event order, so 'n' events are
        expected to arrive for ids 0, 1, 2, ... in sequence. Parent entries
        equal to -1 are dropped. Any other event type raises Exception.'''
        heads = set()
        labels = {}
        pars = []

        for t, d in parse:
            if t == 'n':
                r, ps = d
                pars.append([p for p in ps if p != -1])
                # a node is a head until a later node names it as parent
                heads.add(r)
                for p in ps:
                    heads.discard(p)
            elif t == 'l':
                r, name = d
                labels.setdefault(name, []).append(r)
            else:
                raise Exception('unknown operation: ' + t)

        return MemDAG(sorted(heads), pars, labels)

    @staticmethod
    def fromdesc(desc):
        '''build a MemDAG from text parsed by dagparser.parsedag()'''
        return MemDAG.fromparse(dagparser.parsedag(desc))

    @staticmethod
    def fromfile(fname):
        '''fromdesc on text read from file'''
        # open() instead of file(): file() no longer exists in Python 3
        f = open(fname)
        try:
            return MemDAG.fromdesc(f.read())
        finally:
            f.close()

if __name__ == "__main__":
    # Round-trip each sample DAG: parse data/<name>.dag and write the
    # re-serialized text next to it as data/<name>.dag2.
    for n in ("hg", "linux", "mini", "netbeans", "xenbits"):
        fn = "data/" + n + ".dag"
        d = MemDAG.fromfile(fn)
        t = dagparser.dagtext(d.asparse(d.nodeset()))
        # open() instead of file(): file() no longer exists in Python 3
        f = open(fn + "2", "w")
        try:
            f.write(t)
        finally:
            f.close()