Commits

Chris Mutel committed 81108ae

Initial commit. Functionality exists; needs documentation

Comments (0)

Files changed (17)

+template tags and filters
+pygraph
+pyrecursion
+VERSION = (0, 0.1, None)

dag/content_type_utils.py

+# Utilities for working with generic relations
+# From soclone project: http://code.google.com/p/soclone/source/browse/trunk/soclone/utils/models.py
+
+from django.contrib.contenttypes.models import ContentType
+import itertools
+
+# t = TransitiveClosure.objects.filter(**kwargs).values('node_from', 'node_to', 'depth')
+# n = Node.objects.filter(id__in=[x['node_to'] for x in t])
+
+def fetch_model_dict(model, ids, fields=None):
+    """
+    Fetches a dict of model details for model instances with the given
+    ids, keyed by their id.
+
+    If a fields list is given, a dict of details will be retrieved for
+    each model, otherwise complete model instances will be retrieved.
+
+    Any fields list given shouldn't contain the primary key attribute for
+    the model, as this can be determined from its Options.
+    """
+    if fields is None:
+        return model._default_manager.in_bulk(ids)
+    else:
+        id_attr = model._meta.pk.attname
+        return dict((obj[id_attr], obj) for obj
+            in model._default_manager.filter(id__in=ids).values(
+                *itertools.chain((id_attr,), fields)))
+
+def populate_content_object_caches(generic_related_objects, model_fields=None):
+    """
+    Retrieves ``ContentType`` and content objects for the given list of
+    items which use a generic relation, grouping the retrieval of content
+    objects by model to reduce the number of queries executed.
+
+    This results in ``number_of_content_types + 1`` queries rather than
+    the ``number_of_generic_reL_objects * 2`` queries you'd get by
+    iterating over the list and accessing each item's object attribute.
+
+    If a dict mapping model classes to field names is given, only the
+    given fields will be looked up for each model specified and the
+    object cache will be populated with a dict of the specified fields.
+    Otherwise, complete model instances will be retrieved.
+
+    """
+    if model_fields is None:
+        model_fields = {}
+
+    # Group content object ids by their content type ids
+    ids_by_content_type = {}
+    for obj in generic_related_objects:
+        ids_by_content_type.setdefault(obj.content_type_id,
+                                       []).append(obj.object_id)
+
+    # Retrieve content types and content objects in bulk
+    content_types = ContentType.objects.in_bulk(ids_by_content_type.keys())
+    for content_type_id, ids in ids_by_content_type.iteritems():
+        model = content_types[content_type_id].model_class()
+        #TODO: objects is not yet defined
+        objects[content_type_id] = fetch_model_dict(
+            model, tuple(set(ids)), model_fields.get(model, None))
+
+    # Set content types and content objects in the appropriate cache
+    # attributes, so accessing the 'content_type' and 'object' attributes
+    # on each object won't result in further database hits.
+    for obj in generic_related_objects:
+        obj._object_cache = objects[obj.content_type_id][obj.object_id]
+        obj._content_type_cache = content_types[obj.content_type_id]

dag/docs/rst/index.rst

+Introduction
+============
+
+* django-acyclic * provides handling of directed acyclic graphs of generic objects in the Django web framework. Currently, only PostgreSQL is supported.
+
+Simple Example
+--------------
+
+::
+
+    from dag.mixins import GraphMixin
+    from dag.models import Edge, Graph
+    
+    class Continent(GraphMixin):
+        name = models.TextField()
+    
+    class Country(GraphMixin):
+        name = models.TextField()
+    
+    class City(GraphMixin):
+        name = models.TextField()
+    
+    graph = Graph.objects.create(name="example graph")
+    Europe = Continent.objects.create(name="Europe")
+    Switzerland = Country.objects.create(name="Switzerland")
+    Zurich = City.objects.create(name="Zurich")
+    # One way to create edges
+    Edge.objects.create(node_from=Europe, node_to=Switzerland)
+    # Another way to create edges
+    Switzerland.link_to(Zurich)
+    Europe.get_descendants() == [Switzerland, Zurich]
+    >>> True
+    graph.is_heterogeneous
+    >>> True
+    graph.root_node_objects
+    >>> [Europe,]
+
+Prerequisites
+-------------
+
+- Django >= 1.1
+- PostgreSQL >= 8.0
+
+class CircularReferenceError(StandardError):
+    pass
+
+class SeparateGraphsError(StandardError):
+    pass
+
+class MultipleReferencesError(StandardError):
+    pass
+
+class DuplicateEdgeError(StandardError):
+    pass
+
+class MissingGraphError(StandardError):
+    pass

dag/management/__init__.py

+from django.utils.translation import ugettext as _
+from django.db import connection, transaction
+from django.db.models import signals
+
+def create_tables_and_triggers(sender, **kwargs):
+    # print "Running 'create_tables_and_triggers' function"
+    cursor = connection.cursor()
+    
+    # Test if plpgsql is installed
+    TEST_PLPGSQL_SQL = """select count(*) from pg_language where lanname = 'plpgsql'"""
+    cursor.execute(TEST_PLPGSQL_SQL)
+    plpgsql_installed = cursor.fetchone()[0]
+    if not plpgsql_installed:
+        INSTALL_PLPGSQL_SQL = """CREATE TRUSTED PROCEDURAL LANGUAGE 'plpgsql' HANDLER plpgsql_call_handler VALIDATOR plpgsql_validator"""
+        cursor.execute(INSTALL_PLPGSQL_SQL)
+        transaction.commit_unless_managed()
+    
+    # Exit if already installed
+    # TEST_DAG_EDGE_TABLE_SQL = """select count(*) from information_schema.tables where table_name = 'dag_edge'"""
+    TEST_FUNCTION_EXISTS = """SELECT count(*) FROM pg_proc where proname = 'enforce_acyclicity'"""
+    cursor.execute(TEST_FUNCTION_EXISTS)
+    if cursor.fetchone()[0]:
+        return
+    
+    try:
+        transaction.enter_transaction_management()
+        transaction.managed(True)
+        # Set default depth to 0 for database triggers
+        DEFAULT_DEPTH_SQL = """alter table dag_transitive_closure alter column depth set default 0"""
+        cursor.execute(DEFAULT_DEPTH_SQL)
+    
+        ADD_EDGE_CHECK_SQL = """ALTER TABLE dag_edge ADD CHECK (node_to_id <> node_from_id)"""
+        cursor.execute(ADD_EDGE_CHECK_SQL)
+    
+        TRANSITIVE_CLOSURE_INDEX_1_SQL = """create index idx_trans_from_graph_to on dag_transitive_closure(node_from_id, graph_id, node_to_id)"""
+        cursor.execute(TRANSITIVE_CLOSURE_INDEX_1_SQL)
+    
+        TRANSITIVE_CLOSURE_INDEX_2_SQL = """create index idx_trans_to_graph on dag_transitive_closure(node_to_id, graph_id)"""
+        cursor.execute(TRANSITIVE_CLOSURE_INDEX_2_SQL)
+    
+        ENFORCE_ACYCLICALITY_FUNCTION = """create function enforce_acyclicity() returns trigger as
+        $$
+        begin
+
+        if exists(select 1 from dag_transitive_closure where node_to_id=NEW.node_from_id and node_from_id=NEW.node_to_id and graph_id=NEW.graph_id) then
+            raise exception 'Inserting (%%,%%) will create a loop.', NEW.node_from_id, NEW.node_to_id;
+        end if;
+
+        return NEW;
+
+        end;
+        $$ language plpgsql"""
+        ENFORCE_ACYCLICALITY_TRIGGER = """create trigger trig_enforce_acyclicity before insert on dag_edge
+            for each row execute procedure enforce_acyclicity()"""
+        # TODO: Fix. Gives strange error.
+        cursor.execute(ENFORCE_ACYCLICALITY_FUNCTION)
+        cursor.execute(ENFORCE_ACYCLICALITY_TRIGGER)
+
+        ADD_IMPLIED_EDGES_FUNCTION = """create function add_implied_edges() returns trigger
+            as $$
+            declare
+                id int;
+            begin
+                id := nextval('dag_transitive_closure_t_edge_id_seq');
+                insert into dag_transitive_closure (node_from_id, node_to_id, graph_id, entry_id, direct_id, exit_id, t_edge_id) values (new.node_from_id, new.node_to_id, new.graph_id, id, id, id, id);
+
+                insert into dag_transitive_closure (direct_id, exit_id, entry_id, node_from_id, node_to_id, graph_id, depth)
+
+                    -- Incoming edges.
+                    select id, id, t_edge_id, node_from_id, new.node_to_id, new.graph_id, depth + 1
+                        from dag_transitive_closure
+                        where node_to_id = new.node_from_id
+                              and graph_id=new.graph_id
+
+                    union
+
+                    -- Outgoing edges.
+                    select id, t_edge_id, id, new.node_from_id, node_to_id, new.graph_id, depth + 1
+                        from dag_transitive_closure
+                        where node_from_id = new.node_to_id
+                              and graph_id=new.graph_id
+
+                    union
+
+                    -- Incoming to outgoing.
+                    select a.t_edge_id, id, b.t_edge_id, a.node_from_id, b.node_to_id, new.graph_id, a.depth + b.depth + 1
+                        from dag_transitive_closure a
+                             cross join dag_transitive_closure b
+                        where a.node_to_id = new.node_from_id and b.node_from_id = new.node_to_id
+                              and a.graph_id=new.graph_id and b.graph_id=new.graph_id;
+
+                return null;
+            end;
+            $$ language plpgsql"""
+        ADD_IMPLIED_EDGES_TRIGGER = """create trigger trig_add_implied_edges after insert on dag_edge
+            for each row execute procedure add_implied_edges()"""
+        REMOVE_IMPLIED_EDGES_FUNCTION = """create function remove_implied_edges() returns trigger
+            as $$
+            begin
+                create temporary table purge_list as
+                    -- The direct edge.
+                    select direct_id as t_edge_id
+                    from dag_transitive_closure
+                    where node_from_id = old.node_from_id
+                          and node_to_id = old.node_to_id
+                          and graph_id = old.graph_id;
+                while true
+                loop
+                    insert into purge_list
+                    -- Edges dependant of those in the purge list.
+                    select t_edge_id
+                        from dag_transitive_closure
+                        where
+                            depth > 0
+                            and t_edge_id not in ( select t_edge_id from purge_list )
+                            and (
+                                entry_id in ( select t_edge_id from purge_list )
+                                or exit_id in ( select t_edge_id from purge_list )
+                            );
+                    if not found then
+                        exit;
+                    end if;
+                end loop;
+                delete from dag_transitive_closure
+                    where t_edge_id in (
+                        select t_edge_id
+                        from purge_list
+                    );
+                drop table purge_list;
+                return null;
+            end;
+            $$ language plpgsql"""
+        REMOVE_IMPLIED_EDGES_TRIGGER = """create trigger trig_remove_implied_edges after delete on dag_edge
+            for each row execute procedure remove_implied_edges()"""
+    
+        cursor.execute(ADD_IMPLIED_EDGES_FUNCTION)
+        cursor.execute(ADD_IMPLIED_EDGES_TRIGGER)
+        cursor.execute(REMOVE_IMPLIED_EDGES_FUNCTION)
+        cursor.execute(REMOVE_IMPLIED_EDGES_TRIGGER)
+        transaction.commit()
+        print "Triggers and tables created successfully"
+
+    except:
+        import traceback        
+        print traceback.print_exc()
+        transaction.rollback()
+    
+    finally:
+        transaction.leave_transaction_management()
+
+signals.post_syncdb.connect(create_tables_and_triggers)
+from django.db import models
+from django.contrib.contenttypes.models import ContentType
+from utils import get_objects_for_node
+from models import TransitiveClosure, Node
+from errors import MultipleReferencesError
+
+class GraphMixin(models.Model):
+    """Mixin class that adds methods to define, explore, and manipulate a directed acyclic graph."""
+    class Meta:
+        abstract = True
+    
+    @property
+    def content_type(self):
+        if not hasattr(self, "_content_type"):
+            self._content_type = ContentType.objects.get(
+                app_label=self._meta.app_label, model=self._meta.module_name)
+        return self._content_type
+    
+    def _node(self, graph=None):
+        node = Node.objects.filter(content_type=self.content_type, object_id=self.pk)
+        if graph:
+            node = node.filter(graph=graph)
+        if node.count() == 0:
+            return None
+        elif node.count() > 1:
+            raise MultipleReferencesError, "This object is associated with multiple nodes. Please specify the graph."
+        return node[0]
+    
+    def get_parents(self, graph=None, include_self=False):
+        objs = get_objects_for_node(node=self._node(graph), downwards=False, max_depth=0)
+        if include_self:
+            objs.insert(0, self)
+        return objs
+    
+    def get_children(self, graph=None, include_self=False):
+        objs = get_objects_for_node(node=self._node(graph), max_depth=0)
+        if include_self:
+            objs.insert(0, self)
+        return objs
+    
+    def get_ancestors(self, graph=None, include_self=False):
+        objs = get_objects_for_node(node=self._node(graph), downwards=False)
+        if include_self:
+            objs.insert(0, self)
+        return objs
+    
+    def get_descendants(self, graph=None, include_self=False):
+        objs = get_objects_for_node(node=self._node(graph))
+        if include_self:
+            objs.insert(0, self)
+        return objs
+    
+    def get_parent_count(self, graph=None):
+        return TransitiveClosure.objects.filter(node_to=self._node(
+            graph), depth=0).count()
+    
+    def get_ancestor_count(self, graph=None):
+        return TransitiveClosure.objects.filter(node_to=self._node(
+            graph)).values("node_from", "node_to").distinct().count()
+    
+    def get_child_count(self, graph=None):
+        return TransitiveClosure.objects.filter(node_from=self._node(
+            graph), depth=0).count()
+    
+    def get_descendant_count(self, graph=None):
+        return TransitiveClosure.objects.filter(node_from=self._node(
+            graph)).values("node_from", "node_to").distinct().count()
+    
+    def is_child_node(self, graph=None):
+        return self.get_parent_count(graph) > 0
+    
+    def is_root_node(self, graph=None):
+        return self.get_parent_count(graph) == 0
+    
+    def is_leaf_node(self, graph=None):
+        return self.get_child_count(graph) == 0
+from django.db import models
+from django.utils.translation import ugettext as _
+from django.contrib.contenttypes import generic
+from django.contrib.contenttypes.models import ContentType
+from errors import *
+from utils import create_objects_iterable, flatten
+
+class Graph(models.Model):
+    name = models.TextField(unique=True)
+    
+    class Meta:
+        db_table = "dag_graph"
+    
+    def __unicode__(self):
+        return "Graph: %s" % self.name
+    
+    @property
+    def is_heterogeneous(self):
+        return self.content_types_number > 1
+    
+    @property
+    def content_types_number(self):
+        return Node.objects.filter(graph__id=self.id).values('content_type').distinct().count()
+
+    def combine_graphs(self, other, delete=True):
+        if not isinstance(other, Graph):
+            raise TypeError, "Must combine with Graph instance."
+        
+        # Check for multiple references
+        first_nodes = Node.objects.filter(graph__id=self.id).values_list("id", flat=True)
+        second_nodes = Node.objects.filter(graph__id=other.id).values_list("id", flat=True)
+        if list(set(first_nodes).intersection(set(second_nodes))):
+            raise MultipleReferencesError, "Can't combine graphs with the same node"
+        
+        edges = Edge.objects.filter(graph=other).values("node_from", "node_to")
+        nodes = Node.objects.filter(graph=other)
+        nodes.update(graph=self)
+        
+        # Will delete TransitiveClosure objects on cascade
+        Edge.objects.filter(graph=other).delete()
+        
+        # Cache in dictionary for speedy lookups
+        node_dict = {}
+        for d in edges:
+            if d["node_from"] not in node_dict:
+                d["node_from"] = Node.objects.get(id=d["node_from"])
+            if d["node_to"] not in node_dict:
+                d["node_to"] = Node.objects.get(id=d["node_to"])
+            Edge.objects.create(d["node_from"], d["node_to"], graph=self)
+        
+        if delete:
+            other.delete()
+    
+    @property
+    def root_nodes(self):
+        """Get root nodes for a graph"""
+        return Node.objects.filter(graph=self).exclude(id__in=Edge.objects.filter(
+            node_to__graph=self).values_list('node_to', flat=True))
+
+    @property
+    def root_node_objects(self):
+        """Get objects for root nodes helper"""
+        root_nodes = self.root_nodes
+    
+        tc_query = root_nodes.values_list('graph', 'id', 'content_type', 'object_id')
+    
+        # Create dictionary with content type ids as keys, and objects ids as values
+        content_type_dict = {}
+        for tup in tc_query:
+            content_type_dict.setdefault(tup[2], []).append(tup[3])
+
+        # Get queryset per content type, and aggregate together
+        content_types = ContentType.objects.filter(id__in=content_type_dict.keys())
+        objs = flatten([create_objects_iterable(c, content_type_dict) for c in content_types])
+        return objs
+
+
+class Node(models.Model):
+    graph = models.ForeignKey(Graph)
+    content_type = models.ForeignKey(ContentType, null=True, blank=True)
+    object_id = models.PositiveIntegerField(null=True, blank=True)
+    reference = generic.GenericForeignKey('content_type', 'object_id')
+    
+    class Meta:
+        db_table = "dag_node"
+        unique_together = ("graph", "content_type", "object_id")
+    
+    def save(self, *args, **kwargs):
+        if not self.id and Node.objects.filter(graph=self.graph, 
+                content_type=self.content_type, object_id=self.object_id
+                ).count() and self.content_type and self.object_id:
+            raise MultipleReferencesError, "This node for this graph already exists."
+        super(Node, self).save(*args, **kwargs)
+    
+    def __unicode__(self):
+        if self.object_id:
+            return "Node: %s" % self.reference
+        else:
+            return "<Node object with id %s at address %s>" % (self.id, id(self))
+    
+    def link_to(self, other_node):
+        return Edge.objects.create(self, other_node, graph=self.graph)
+    
+    def link_from(self, other_node):
+        return Edge.objects.create(other_node, self, graph=self.graph)
+
+
+class EdgeManager(models.Manager):
+    def update(self):
+        raise NotImplementedError, "Edges can't be updated - they must be deleted and re-created."
+    
+    def _parse_to_node(self, node, graph):
+        # Make sure objects exist in database
+        try:
+            # Is this the best way to test if an object has been saved?
+            assert node.pk
+        except AssertionError:
+            raise AttributeError, "Can only link nodes already saved in database"
+        
+        if isinstance(node, Node):
+            return node
+        if not hasattr(node, "_node"):
+            # Not registered with DAG mixin
+            if not graph:
+                raise MissingGraphError, "Must specify graph when creating new node."
+            content_type = ContentType.objects.get(app_label=node._meta.app_label, model=node._meta.module_name)
+            return Node.objects.create(graph=graph, object_id=node.id, content_type=content_type)
+        if node._node(graph=graph):
+            return node._node(graph=graph)
+        else:
+            if not graph:
+                raise MissingGraphError, "Must specify graph when creating new node."
+            return Node.objects.create(graph=graph, content_type=node.content_type, object_id=node.id)
+    
+    def create(self, node_from, node_to, graph=None, combine_graphs=False):
+        # Check that nodes are of correct type
+        node_from = self._parse_to_node(node_from, graph)
+        node_to = self._parse_to_node(node_to, graph)
+        
+        if not graph:
+            graph = node_to.graph
+        
+        different_graphs = node_from.graph != node_to.graph
+        if different_graphs and not combine_graphs:
+            raise SeparateGraphsError
+        elif different_graphs:
+            node_from.graph.combine_graphs(node_to.graph)
+            # Reload nodes to reflect new combined graph
+            node_to = Node.objects.get(id=node_to.id)
+            node_from = Node.objects.get(id=node_from.id)
+            graph = node_to.graph
+        
+        return super(EdgeManager, self).create(graph=graph, node_from=node_from, node_to=node_to)
+
+
+class Edge(models.Model):
+    """Edges are references, and don't have any special attributes"""
+    graph = models.ForeignKey(Graph, blank=True)
+    node_from = models.ForeignKey(Node, related_name="from")
+    node_to = models.ForeignKey(Node, related_name="to")
+    
+    def save(self, *args, **kwargs):
+        if self.id:
+            raise NotImplementedError, "Edges can only be created or deleted, not updated."
+        
+        if self.node_from == self.node_to:
+            raise CircularReferenceError
+        
+        if self.node_from.graph != self.node_to.graph:
+            raise SeparateGraphsError
+        
+        if Edge.objects.filter(node_from=self.node_from, node_to=self.node_to, graph=self.graph).count():
+            raise DuplicateEdgeError
+        
+        if not Edge.check_acyclical(self.node_from, self.node_to):
+            raise CircularReferenceError
+        
+        kwargs["force_insert"] = True
+        super(Edge, self).save(*args, **kwargs)
+
+    objects = EdgeManager()
+
+    def __unicode__(self):
+        return "Edge from %s to %s" % (self.node_from, self.node_to)
+
+    class Meta:
+        db_table = "dag_edge"
+        unique_together = ("graph", "node_from", "node_to")
+
+    @staticmethod
+    def check_acyclical(start, end):
+        if start.graph != end.graph:
+            raise SeparateGraphsError
+        if TransitiveClosure.objects.filter(node_from=end, node_to=start, graph=start.graph).count():
+            return False
+        else:
+            return True
+
+
+class TransitiveClosure(models.Model):
+    t_edge_id = models.AutoField(primary_key=True)
+    graph = models.ForeignKey(Graph)
+    node_from = models.ForeignKey(Node, related_name="node_from")
+    node_to = models.ForeignKey(Node, related_name="node_to")
+    # Not foreign keys because create Edge table after syncdb, and shouldn't touch this directly anyway.
+    entry_id = models.PositiveIntegerField()
+    direct_id = models.PositiveIntegerField()
+    exit_id = models.PositiveIntegerField()
+    depth = models.PositiveIntegerField(default=0)
+
+    class Meta:
+        db_table = "dag_transitive_closure"
+    
+    def __unicode__(self):
+        return "Transitive Closure link: %s to %s" % (self.node_from, self.node_to)
+
+To test: ./manage.py test --settings=dag.tests.settings

dag/tests/__init__.py

+from database import DatabaseTestCase
+from content_types import ContentTypeTestCase
+from different_graphs import DifferentGraphsTestCase

dag/tests/content_types.py

+# -*- coding: utf-8 -*-
+from dag.models import *
+from dag.errors import *
+from models import *
+from django.test import TestCase
+
+class ContentTypeTestCase(TestCase):
+
+    def test_basic(self):
+        """
+          Chur
+         /    \
+      Zurich  Bern
+     /    \  /     \
+  Geneva  Lucern  Lausanne
+    |       |
+Winterthur  |
+       \    |
+       Glarus
+        """
+        g = Graph.objects.create(name="graph")
+        zurich = City.objects.create(name="Zurich")
+        bern = City.objects.create(name="Bern")
+        geneva = City.objects.create(name="Geneva")
+        lucern = City.objects.create(name="Luzern")
+        lausanne = City.objects.create(name="Lausanne")
+        winterthur = City.objects.create(name="Winterthur")
+        glarus = City.objects.create(name="Glarus")
+        chur = City.objects.create(name="Chur")
+        zurich_node = Node(graph=g)
+        zurich_node.reference = zurich
+        zurich_node.save()
+        geneva_node = Node(graph=g)
+        geneva_node.reference = geneva
+        geneva_node.save()
+        bern_node = Node(graph=g)
+        bern_node.reference = bern
+        bern_node.save()
+        lucern_node = Node(graph=g)
+        lucern_node.reference = lucern
+        lucern_node.save()
+        lausanne_node = Node(graph=g)
+        lausanne_node.reference = lausanne
+        lausanne_node.save()
+        winterthur_node = Node(graph=g)
+        winterthur_node.reference = winterthur
+        winterthur_node.save()
+        glarus_node = Node(graph=g)
+        glarus_node.reference = glarus
+        glarus_node.save()
+        chur_node = Node(graph=g)
+        chur_node.reference = chur
+        chur_node.save()
+        Edge.objects.create(chur_node, zurich_node)
+        Edge.objects.create(chur_node, bern_node)
+        Edge.objects.create(winterthur_node, glarus_node)
+        Edge.objects.create(lucern_node, glarus_node)
+        Edge.objects.create(bern_node, lucern_node)
+        Edge.objects.create(zurich_node, geneva_node)
+        Edge.objects.create(bern_node, lausanne_node)
+        Edge.objects.create(geneva_node, winterthur_node)
+        Edge.objects.create(zurich_node, lucern_node)
+        self.assertEqual(lucern.get_parents(), [zurich, bern])
+        self.assertEqual(lucern.get_children(), [glarus,])
+        self.assertEqual(lucern.get_descendants(), [glarus,])
+        self.assertEqual(lucern.get_ancestors(), [zurich, bern, chur])
+        self.assertEqual(glarus.get_descendants(), [])
+        self.assertEqual(glarus.get_children(), [])
+        self.assertEqual(glarus.get_parents(), [lucern, winterthur])
+        self.assertEqual(glarus.get_ancestors(), [lucern, winterthur, zurich, bern, geneva, chur])
+        self.assertEqual(chur.get_ancestors(), [])
+        self.assertEqual(chur.get_descendants(), [zurich, bern, geneva, lucern, lausanne, winterthur, glarus])
+        e = Edge.objects.get(graph=g, node_from=zurich, node_to=lucern)
+        e.delete()
+        self.assertEqual(zurich.get_children(), [geneva,])
+        self.assertEqual(zurich.get_descendants(), [geneva, winterthur, glarus])
+        self.assertFalse(g.is_heterogeneous)
+    
+    def test_multiple_ct_references(self):
+        g = Graph.objects.create(name="graph")
+        zurich = City.objects.create(name="Zurich")
+        n1 = Node(graph=g)
+        n1.reference = zurich
+        n1.save()
+        n2 = Node(graph=g)
+        n2.reference = zurich
+        self.assertRaises(MultipleReferencesError, n2.save)
+    
+    def test_heterogeneous_trees_1(self):
+        g = Graph.objects.create(name="graph")
+        zurich = City.objects.create(name="Zurich")
+        zurich_node = Node(graph=g)
+        zurich_node.reference = zurich
+        zurich_node.save()
+        bern = City.objects.create(name="Bern")
+        bern_node = Node(graph=g)
+        bern_node.reference = bern
+        bern_node.save()
+        blue = Color.objects.create(name="Blue")
+        blue_node = Node(graph=g)
+        blue_node.reference = blue
+        blue_node.save()
+        green = Color.objects.create(name="Green")
+        green_node = Node(graph=g)
+        green_node.reference = green
+        green_node.save()
+        Edge.objects.create(zurich_node, blue_node)
+        Edge.objects.create(blue_node, bern_node)
+        Edge.objects.create(bern_node, green_node)
+        self.assertEqual(zurich.get_descendants(), [blue, bern, green])
+        self.assertTrue(g.is_heterogeneous)
+    
+    def test_heterogeneous_trees_2(self):
+        g = Graph.objects.create(name="graph")
+        zurich = City.objects.create(name="Zurich")
+        bern = City.objects.create(name="Bern")
+        blue = Color.objects.create(name="Blue")
+        green = Color.objects.create(name="Green")
+        Edge.objects.create(zurich, blue, graph=g)
+        Edge.objects.create(blue, bern, graph=g)
+        Edge.objects.create(bern, green, graph=g)
+        self.assertEqual(zurich.get_descendants(graph=g), [blue, bern, green])
+        self.assertTrue(g.is_heterogeneous)
+    
+    def test_blank_nodes(self):
+        g = Graph.objects.create(name="graph")
+        zurich = City.objects.create(name="Zurich")
+        bern = City.objects.create(name="Bern")
+        blank_node = Node.objects.create(graph=g)
+        Edge.objects.create(zurich, blank_node, graph=g)
+        Edge.objects.create(blank_node, bern, graph=g)
+        self.assertEqual(zurich.get_descendants(), [blank_node, bern])
+    
+    def test_edge_creation_without_node_creation(self):
+        g = Graph.objects.create(name="graph")
+        zurich = City.objects.create(name="Zurich")
+        bern = City.objects.create(name="Bern")
+        geneva = City.objects.create(name="Geneva")
+        Edge.objects.create(zurich, bern, graph=g)
+        Edge.objects.create(bern, geneva, graph=g)
+        self.assertEqual(zurich.get_descendants(), [bern, geneva])
+    
+    def test_counts(self):
+        """
+          Chur
+         /    \
+      Zurich  Bern
+     /    \  /
+  Geneva  Lucern
+        """
+        g = Graph.objects.create(name="graph")
+        zurich = City.objects.create(name="Zurich")
+        bern = City.objects.create(name="Bern")
+        geneva = City.objects.create(name="Geneva")
+        lucern = City.objects.create(name="Luzern")
+        lausanne = City.objects.create(name="Lausanne")
+        winterthur = City.objects.create(name="Winterthur")
+        glarus = City.objects.create(name="Glarus")
+        chur = City.objects.create(name="Chur")
+        zurich_node = Node(graph=g)
+        zurich_node.reference = zurich
+        zurich_node.save()
+        geneva_node = Node(graph=g)
+        geneva_node.reference = geneva
+        geneva_node.save()
+        bern_node = Node(graph=g)
+        bern_node.reference = bern
+        bern_node.save()
+        lucern_node = Node(graph=g)
+        lucern_node.reference = lucern
+        lucern_node.save()
+        chur_node = Node(graph=g)
+        chur_node.reference = chur
+        chur_node.save()
+        Edge.objects.create(chur_node, zurich_node)
+        Edge.objects.create(chur_node, bern_node)
+        Edge.objects.create(bern_node, lucern_node)
+        Edge.objects.create(zurich_node, geneva_node)
+        Edge.objects.create(zurich_node, lucern_node)
+        
+        self.assertEqual(chur.get_child_count(), 2)
+        self.assertFalse(chur.is_child_node())
+        self.assertFalse(chur.is_leaf_node())
+        self.assertTrue(chur.is_root_node())
+        self.assertEqual(chur.get_descendant_count(), 4)
+        self.assertEqual(chur.get_parent_count(), 0)
+        self.assertEqual(chur.get_ancestor_count(), 0)
+        
+        self.assertEqual(zurich.get_child_count(), 2)
+        self.assertTrue(zurich.is_child_node())
+        self.assertFalse(zurich.is_leaf_node())
+        self.assertFalse(zurich.is_root_node())
+        self.assertEqual(zurich.get_descendant_count(), 2)
+        self.assertEqual(zurich.get_parent_count(), 1)
+        self.assertEqual(zurich.get_ancestor_count(), 1)
+        
+        self.assertEqual(lucern.get_child_count(), 0)
+        self.assertTrue(lucern.is_child_node())
+        self.assertTrue(lucern.is_leaf_node())
+        self.assertFalse(lucern.is_root_node())
+        self.assertEqual(lucern.get_descendant_count(), 0)
+        self.assertEqual(lucern.get_parent_count(), 2)
+        self.assertEqual(lucern.get_ancestor_count(), 3)
+    
+    def test_root_nodes(self):
+        """
+     Zurich  Bern
+     /    \  /
+    Geneva  Lucern
+        """
+        g = Graph.objects.create(name="graph")
+        zurich = City.objects.create(name="Zurich")
+        bern = City.objects.create(name="Bern")
+        geneva = City.objects.create(name="Geneva")
+        lucern = City.objects.create(name="Luzern")
+        zurich_node = Node(graph=g)
+        zurich_node.reference = zurich
+        zurich_node.save()
+        geneva_node = Node(graph=g)
+        geneva_node.reference = geneva
+        geneva_node.save()
+        bern_node = Node(graph=g)
+        bern_node.reference = bern
+        bern_node.save()
+        lucern_node = Node(graph=g)
+        lucern_node.reference = lucern
+        lucern_node.save()
+        Edge.objects.create(bern_node, lucern_node)
+        Edge.objects.create(zurich_node, geneva_node)
+        Edge.objects.create(zurich_node, lucern_node)
+        self.assertEqual(set(g.root_nodes), set([bern_node, zurich_node]))
+        self.assertEqual(set(g.root_node_objects), set([zurich, bern]))
+        

dag/tests/database.py

+# -*- coding: utf-8 -*-
+from dag.models import *
+from dag.errors import *
+from django.test import TestCase
+
+class DatabaseTestCase(TestCase):
+
+    def test_basic(self):
+        g = Graph.objects.create(name="graph")
+        n1 = Node.objects.create(graph=g)
+        n2 = Node.objects.create(graph=g)
+        n3 = Node.objects.create(graph=g)
+        Edge.objects.create(n1, n2)
+        Edge.objects.create(n2, n3)
+        self.assertEqual(Node.objects.all().count(), 3)
+        self.assertEqual(Edge.objects.all().count(), 2)
+        self.assertEqual(TransitiveClosure.objects.all().count(), 3)
+    
+    def test_linking_same_node(self):
+        g = Graph.objects.create(name="graph")
+        n1 = Node.objects.create(graph=g)
+        self.assertRaises(CircularReferenceError, Edge.objects.create, n1, n1)
+        self.assertRaises(CircularReferenceError, Edge.objects.create, n1, n1, g)
+    
+    def test_circular_reference(self):
+        g = Graph.objects.create(name="graph")
+        n1 = Node.objects.create(graph=g)
+        n2 = Node.objects.create(graph=g)
+        n3 = Node.objects.create(graph=g)
+        Edge.objects.create(n1, n2)
+        Edge.objects.create(n2, n3)
+        self.assertRaises(CircularReferenceError, Edge.objects.create, n3, n1)
+    
+    def test_delete_edge(self):
+        g = Graph.objects.create(name="graph")
+        n1 = Node.objects.create(graph=g)
+        n2 = Node.objects.create(graph=g)
+        n3 = Node.objects.create(graph=g)
+        Edge.objects.create(n1, n2)
+        Edge.objects.create(n2, n3)
+        Edge.objects.get(node_from=n1, node_to=n2).delete()
+        self.assertEqual(Node.objects.all().count(), 3)
+        self.assertEqual(Edge.objects.all().count(), 1)
+        self.assertEqual(TransitiveClosure.objects.all().count(), 1)
+    
+    def test_update_edge(self):
+        g = Graph.objects.create(name="graph")
+        n1 = Node.objects.create(graph=g)
+        n2 = Node.objects.create(graph=g)
+        n3 = Node.objects.create(graph=g)
+        Edge.objects.create(n1, n2)
+        e = Edge.objects.get(node_from=n1, node_to=n2)
+        e.node_to = n3
+        self.assertRaises(NotImplementedError, e.save)
+    
+    def test_add_duplicate_edge(self):
+        g = Graph.objects.create(name="graph")
+        n1 = Node.objects.create(graph=g)
+        n2 = Node.objects.create(graph=g)
+        n3 = Node.objects.create(graph=g)
+        Edge.objects.create(n1, n2)
+        self.assertRaises(DuplicateEdgeError, Edge.objects.create, n1, n2)
+    
+    def test_diamond(self):
+        g = Graph.objects.create(name="graph")
+        n1 = Node.objects.create(graph=g)
+        n2 = Node.objects.create(graph=g)
+        n3 = Node.objects.create(graph=g)
+        n4 = Node.objects.create(graph=g)
+        Edge.objects.create(n1, n2)
+        Edge.objects.create(n1, n3)
+        Edge.objects.create(n2, n4)
+        Edge.objects.create(n3, n4)
+        self.assertEqual(TransitiveClosure.objects.filter(depth=1).count(), 2)
+        self.assertEqual(TransitiveClosure.objects.filter(node_from=n1, node_to=n4).count(), 2)
+        self.assertEqual(TransitiveClosure.objects.all().count(), 6)
+        Edge.objects.get(node_from=n3, node_to=n4).delete()
+        self.assertEqual(TransitiveClosure.objects.filter(depth=1).count(), 1)
+        self.assertEqual(TransitiveClosure.objects.filter(node_from=n1, node_to=n4).count(), 1)
+        self.assertEqual(TransitiveClosure.objects.all().count(), 4)

dag/tests/different_graphs.py

+# -*- coding: utf-8 -*-
+from dag.models import *
+from dag.errors import *
+from django.test import TestCase
+
+class DifferentGraphsTestCase(TestCase):
+    def test_combine_graphs(self):
+        g1 = Graph.objects.create(name="graph 1")
+        n1 = Node.objects.create(graph=g1)
+        g2 = Graph.objects.create(name="graph 2")
+        g2_id = g2.id
+        n2 = Node.objects.create(graph=g2)
+        self.assertRaises(SeparateGraphsError, Edge.objects.create, n1, n2)
+        Edge.objects.create(n1, n2, combine_graphs=True)
+        self.assertEqual(Graph.objects.all().count(), 1)
+        self.assertEqual(Edge.objects.filter(graph__id=g2_id).count(), 0)
+        self.assertEqual(Node.objects.filter(graph__id=g2_id).count(), 0)
+    

dag/tests/models.py

+from django.db import models
+from dag.mixins import GraphMixin
+
+class City(GraphMixin):
+    name = models.TextField(unique=True)
+
+    def __unicode__(self):
+        return self.name
+
+class Color(GraphMixin):
+    name = models.TextField(unique=True)
+    
+    def __unicode__(self):
+        return self.name

dag/tests/settings.py

+DATABASE_ENGINE = 'postgresql_psycopg2'
+DATABASE_NAME = 'django-graph'
+DATABASE_USER = 'django'
+DATABASE_PASSWORD = 'django'
+DATABASE_HOST = 'localhost'
+DATABASE_PORT = '5432'
+
+INSTALLED_APPS = (
+    'dag',
+    'dag.tests',
+    'django.contrib.contenttypes',
+)

dag/transitive_closure.sql

+/*
+
+Schema and triggers for maintaining the transitive closure and acyclicity of
+an incrementally created DAG.
+
+Adapted from http://www.codeproject.com/KB/database/Modeling_DAGs_on_SQL_DBs.aspx
+which references
+
+    Guozhu Dong et al., "Maintaining Transitive Closure of Graphs in SQL"
+    http://www.comp.nus.edu.sg/~wongls/psZ/dlsw-ijit97-16.ps
+
+which in turn references other articles by the same author with proof of
+correctness.
+
+*/
+begin;
+
+create table node (
+    node_id     int,
+    primary key (node_id)
+);
+
+create table edge (
+    node_from   int         references node(node_id)
+                            on update cascade on delete cascade,
+    node_to     int         references node(node_id)
+                            on update cascade on delete cascade,
+    graph_id        int default 42,
+    primary key (node_to, node_from, graph_id),
+    check (node_to != node_from)
+);
+
+create function no_edge_updates() returns trigger as $$
+begin raise exception 'Cannot update edge-table. Delete and insert instead.'; end;
+$$ language plpgsql;
+
+create trigger trig_no_edge_updates before update on edge
+    for each row execute procedure no_edge_updates();
+
+create sequence t_edge_id;
+create table transitive_closure (
+    t_edge_id       int         default nextval('t_edge_id'),
+    graph_id            int ,
+    node_from       int         not null references node(node_id)
+                                         on delete cascade on update cascade,
+    node_to         int         not null references node(node_id)
+                                         on delete cascade on update cascade,
+    -- Auxiliary columns needed to maintain deletions:
+    -- These references justify the surrogate key t_edge_id.
+    -- The lack of foreign keys on them is due to horrible delete-performance.
+    -- It should only be the triggers modifying this table anyway.
+    entry_id        int         not null,
+    direct_id       int         not null,
+    exit_id         int         not null,
+    depth           int         not null default 0,
+    primary key (t_edge_id)
+);
+
+create index idx_trans_from_graph_to on transitive_closure(node_from, graph_id, node_to);
+create index idx_trans_to_graph      on transitive_closure(node_to, graph_id);
+
+create function enforce_acyclicity() returns trigger as
+$$
+begin
+
+if exists(select 1 from transitive_closure where node_to=NEW.node_from and node_from=NEW.node_to and graph_id=NEW.graph_id) then
+    raise exception 'Inserting (%,%) will create a loop.', NEW.node_from, NEW.node_to;
+end if;
+
+return NEW;
+
+end;
+$$ language plpgsql;
+
+create trigger trig_enforce_acyclicity before insert on edge
+    for each row execute procedure enforce_acyclicity();
+
+create function add_implied_edges() returns trigger
+as $$
+declare
+    id int;
+begin
+    id := nextval('t_edge_id');
+    insert into transitive_closure (node_from, node_to, graph_id, entry_id, direct_id, exit_id, t_edge_id) values (new.node_from, new.node_to, new.graph_id, id, id, id, id);
+
+    insert into transitive_closure (direct_id, exit_id, entry_id, node_from, node_to, graph_id, depth)
+
+        -- Incoming edges.
+        select id, id, t_edge_id, node_from, new.node_to, new.graph_id, depth + 1
+            from transitive_closure
+            where node_to = new.node_from
+                  and graph_id=new.graph_id
+
+        union
+
+        -- Outgoing edges.
+        select id, t_edge_id, id, new.node_from, node_to, new.graph_id, depth + 1
+            from transitive_closure
+            where node_from = new.node_to
+                  and graph_id=new.graph_id
+
+        union
+
+        -- Incoming to outgoing.
+        select a.t_edge_id, id, b.t_edge_id, a.node_from, b.node_to, new.graph_id, a.depth + b.depth + 1
+            from transitive_closure a
+                 cross join transitive_closure b
+            where a.node_to = new.node_from and b.node_from = new.node_to
+                  and a.graph_id=new.graph_id and b.graph_id=new.graph_id;
+
+    return null;
+end;
+$$ language plpgsql;
+
+create trigger trig_add_implied_edges after insert on edge
+    for each row execute procedure add_implied_edges();
+
+create function remove_implied_edges() returns trigger
+as $$
+begin
+
+    create temporary table purge_list as
+        -- The direct edge.
+        select direct_id as t_edge_id
+        from transitive_closure
+        where node_from = old.node_from
+              and node_to = old.node_to
+              and graph_id = old.graph_id;
+
+    while true
+    loop
+        insert into purge_list
+        -- Edges dependant of those in the purge list.
+        select t_edge_id
+            from transitive_closure
+            where
+                depth > 0
+                and t_edge_id not in ( select t_edge_id from purge_list )
+                and (
+                    entry_id in ( select t_edge_id from purge_list )
+                    or exit_id in ( select t_edge_id from purge_list )
+                );
+        if not found then
+            exit;
+        end if;
+    end loop;
+
+    delete from transitive_closure
+        where t_edge_id in (
+            select t_edge_id
+            from purge_list
+        );
+
+    drop table purge_list;
+    return null;
+end;
+$$ language plpgsql;
+
+create trigger trig_remove_implied_edges after delete on edge
+    for each row execute procedure remove_implied_edges();
+
+-- (1) -> (2) -> (3) -> (4)  (5)
+--  \--------------------^
+
+insert into node values (1), (2), (3), (4), (5);
+insert into edge values (1,2), (2,3), (3,4), (1,4);
+
+select node_from, node_to, graph_id, depth
+    from transitive_closure
+    order by node_to, node_from, depth;
+
+delete from edge where node_from=2 and node_to=3;
+delete from edge where node_from=1 and node_to=2;
+insert into edge values (5,1);
+
+-- (1) (2) (3)->(4) (5)
+--  \------------^   /
+--  ^---------------/
+
+select node_from, node_to, graph_id, depth
+    from transitive_closure
+    order by node_to, node_from, depth;
+
+delete from node where node_id=1;
+
+-- (2) (3)->(4) (5)
+
+select node_from, node_to, graph_id, depth
+    from transitive_closure
+    order by node_to, node_from, depth;
+
+rollback;
+from django.contrib.contenttypes.models import ContentType
+from django.db import connection, transaction
+from errors import CircularReferenceError, SeparateGraphsError, DuplicateEdgeError
+
+def flatten(x):
+    """Flattened a set of iterables into a single list.
+    
+    From http://kogs-www.informatik.uni-hamburg.de/~meine/python_tricks"""
+    result = []
+    for el in x:
+        if hasattr(el, '__iter__') and not isinstance(el, basestring):
+            result.extend(flatten(el))
+        else:
+            result.append(el)
+    return result
+
+def get_objects_for_node(node, max_depth=None, downwards=True, sort_reverse=False):
+    """Take a node and an optional maximum depth, and return a list of objects linked by generic keys and sorted by depth. 
+
+    Borrows heavily from http://code.google.com/p/soclone/source/browse/trunk/soclone/utils/models.py"""
+    if downwards:
+        # Going 'down' the DAG
+        node_string = 'node_to'
+        node_string_opposite = 'node_from'
+    else:
+        # Going 'up' the DAG
+        node_string = 'node_from'
+        node_string_opposite = 'node_to'
+    cursor = connection.cursor()
+    if max_depth != None:
+        TC_SQL = """select min(tc.depth), n.id, n.content_type_id, n.object_id 
+            from dag_node as n
+            inner join dag_transitive_closure as tc
+            on n.id = tc.%s_id 
+            where tc.%s_id = %%s
+                and tc.depth <= %%s 
+            group by n.id, n.content_type_id, n.object_id, tc.depth
+            order by tc.depth""" % (node_string, node_string_opposite)
+        cursor.execute(TC_SQL, (node.id, max_depth))
+    else:
+        TC_SQL = """select min(tc.depth), n.id, n.content_type_id, n.object_id 
+            from dag_node as n
+            inner join dag_transitive_closure as tc
+            on n.id = tc.%s_id 
+            where tc.%s_id = %%s
+            group by n.id, n.content_type_id, n.object_id, tc.depth
+            order by tc.depth""" % (node_string, node_string_opposite)
+        cursor.execute(TC_SQL, (node.id,))
+    tc_query = cursor.fetchall()
+    
+    # Create dictionary with content type ids as keys, and objects ids as values
+    content_type_dict = {}
+    blank_nodes = []
+    for tup in tc_query:
+        if tup[2]:
+            content_type_dict.setdefault(tup[2], []).append(tup[3]) # CT id, Obj. id
+        else: # Blank content type
+            blank_nodes.append((tup[1], tup[0])) # Node ID, Depth
+
+    # Create dictionary linking (object_id, content_type) to depth
+    depth_dict = dict([((tup[2], tup[3]), tup[0]) for tup in tc_query]) # CT id, Obj. id, Depth
+
+    # Get queryset per content type, and aggregate together
+    content_types = ContentType.objects.filter(id__in=content_type_dict.keys())
+    objs = flatten([create_objects_iterable_with_depth(c, content_type_dict, depth_dict) for c in content_types])
+    
+    # Add nodes without content types
+    from models import Node
+    blank_nodes_queryset = Node.objects.filter(id__in=[x[0] for x in blank_nodes])
+    for index, node in enumerate(blank_nodes_queryset):
+        node._depth = blank_nodes[index][1]
+        objs.append(node)
+    
+    objs = sort_by_attribute(objs, "_depth", sort_reverse=sort_reverse)
+    return objs
+
+def create_objects_iterable_with_depth(content_type, content_type_dict, depth_dict):
+    model = content_type.model_class()
+    objs = model.objects.filter(id__in=content_type_dict[content_type.id])
+    for obj in objs:
+        # Add depth
+        obj._depth = depth_dict[(content_type.id, obj.id)]
+    return objs
+
+def create_objects_iterable(content_type, content_type_dict):
+    model = content_type.model_class()
+    return model.objects.filter(id__in=content_type_dict[content_type.id])
+
+def sort_by_attribute(lst, attr, sort_reverse=False):
+    """Sort a list of objects based on an attribute.
+    
+    Falls back on list index if attributes are equal.
+    Based on http://code.activestate.com/recipes/52230/"""
+    intermed = [ (getattr(obj,attr), index, obj) for index, obj in enumerate(lst)]
+    intermed.sort(reverse = sort_reverse)
+    return [ tup[-1] for tup in intermed ]