Commits

Chad Dombrova committed c090019 Merge

Merge with orm_listener

Comments (0)

Files changed (15)

denormalize/backend/base.py

 from ..context import get_current_context
 
 import logging
+from collections import defaultdict
 from denormalize.models import DocumentCollection
 
 log = logging.getLogger(__name__)
 
-# FIXME:
-from ..orms.django import DjangoConnector
-connector = DjangoConnector()
 
 class BackendBase(object):
 
                 "initialization.".format(name))
         self._registry[name] = self
         self.collections = {}
-        self._listeners = []
+        self._listeners = defaultdict(list)
 
     def register(self, collection):
         """
         self.collections[collection.name] = collection
         self._setup_listeners(collection)
 
-    @staticmethod
-    def _set_affected(ns, collection, instance, affected_set):
-        """Used in pre_* handlers. Annotates an instance with
-        the ids of the root objects it affects, so that we can include
-        these in the final updates in the post_* handlers.
-        """
-        attname = '_denormalize_extra_affected_{0}_{1}'.format(collection.name, ns)
-        current = getattr(instance, attname, set())
-        new = current | affected_set
-        setattr(instance, attname, new)
-
-    @staticmethod
-    def _get_affected(ns, collection, instance):
-        """See _set_affected for an explanation."""
-        attname = '_denormalize_extra_affected_{0}_{1}'.format(collection.name, ns)
-        affected_set = getattr(instance, attname, set())
-        return affected_set
-
     def _get_collection(self, collection_name_or_obj):
         if isinstance(collection_name_or_obj, DocumentCollection):
             name = collection_name_or_obj.name
         :type collection: denormalize.models.DocumentCollection
         """
         dependencies = collection.get_related_models()
-        connector.add_listeners(self, collection, None, collection.model, None)
+        self._add_listeners(collection, None, collection.model, None)
         for filter_path, info in dependencies.items():
-            connector.add_listeners(self, collection, filter_path,
-                                    info['model'], info)
+            self._add_listeners(collection, filter_path, info['model'],
+                                info['relationship'])
+
+    def _add_listeners(self, collection, filter_path, submodel, relationship):
+        """Connect the Django ORM signals to given dependency
+
+        :type collection: denormalize.models.DocumentCollection
+        :param filter_path: ORM filter path (not always the same as the
+            path used in *_related! For example, 'chapter' instead of
+            'chapter_set'!)
+        :type filter_path: basestring, None
+        :param submodel: dependency model to watch for changes
+        :type submodel: denormalize.orms.ModelInspector
+        :param relationship: relationship information
+        :type relationship: denormalize.orms.RelationshipProperty
+        """
+        # TODO: this does not handle changing foreignkeys well. The object
+        #       will be added to the new root, but not deleted from the old
+        #       root object. Maybe solve by also adding a pre_save? The
+        #       database will still contain the old connection.
+        # TODO: Consider moving all of this to a Collection or something
+        #       separate and just listen to signals from there.
+        collection_listener = collection.model.listener
+        # FIXME: pass inspector instead of _model
+        listener = collection_listener(self, collection, submodel._model,
+                                       filter_path, relationship)
+        listener.connect()
+        self._listeners[collection.name].append(listener)
 
     def _queue_deleted(self, collection, doc_id):
         """Queue a deleted notification"""

denormalize/models.py

             filter_path = []
             for accessor in pathstring.split('__'):
                 log.debug("get_related_models: %s (%s)", pathstring, accessor)
-                fieldname, info = model.get_field_info(accessor)
+                rel = model.relation_map()[accessor]
 
-                model = info['model']
+                # set model for next iteration of accessor
+                model = rel.model if rel.model != model else rel.related_model
                 log.debug("get_related_models: %s (%s) => %s",
                           pathstring, accessor, model.name)
+                info = {}
                 info['path'] = pathstring
+                info['model'] = model
+                info['relationship'] = rel
+                # these are no longer needed since they are handled by the
+                # RelationProperty
+                if rel.association_model is not None:
+                    info['through'] = rel.association_model
+                    info['m2m'] = True
+                else:
+                    info['m2m'] = False
 
                 # We also store all intermediate paths # FIXME: include in tests
-                filter_path.append(fieldname)
+                filter_path.append(rel.name)
                 filter_pathstring = '__'.join(filter_path)
 
                 model_info[filter_pathstring] = info
-                # set model for next iteration of accessor
 
         return model_info
 
         return qs.all()
 
     def dump_collection(self):
-        for obj in self.queryset():
-            yield self.dump(obj)
+        # for obj in self.queryset():
+        #    yield self.dump(obj)
+
+        # FIXME: this does not have prefetching yet!!!
+        for obj in self.model.get_native_objects():
+            yield self.dump(obj)            
 
     def dump_id(self, root_pk):
         return self.dump(self.model.get_native_object(root_pk))
                     pass
 
         # *-to-one relationships
-        for relname, field_id in model.iter_relations(obj, 'has one'):
+        rels1 = list(model.iter_relations(obj, 'has one'))
+        rels2 = list(model.iter_relations(None, 'has one'))
+        assert len(rels1) == len(rels2)
+        for relation in model.iter_relations(obj, 'has one'):
+            relname = relation.accessor
             if relname not in excluded:
                 if relname in fields_to_follow:
                     try:
                         data[relname] = self.dump_obj(value.model,
                                                       value.instance,
                                                       path + [relname])
-                elif field_id is not None:
+                elif relation.foreign_key_field is not None:
                    # foreign_key_field is set if the foreign key is local to obj
                     try:
-                        data[field_id] = model.get_value(obj, field_id)
+                        value = model.get_value(obj, relation.foreign_key_field)
                     except Skip:
                         pass
+                    else:
+                        data[relation.foreign_key_field] = value
 
         # *-to-many relationships
-        for relname, _ in model.iter_relations(obj, 'has many'):
+        rels1 = list(model.iter_relations(obj, 'has many'))
+        rels2 = list(model.iter_relations(None, 'has many'))
+        assert len(rels1) == len(rels2)
+        for relation in model.iter_relations(obj, 'has many'):
+            relname = relation.accessor
             if relname in included_relations:
                 try:
                     values = model.get_value(obj, relname)

denormalize/orms/base.py

+import logging
+
+log = logging.getLogger(__name__)
 
 
 class Skip(Exception):
     """Skip serializing a value"""
     pass
 
+
 class MultipleValues(dict):
     """Used to rename keys, or split data into multiple keys"""
     pass
 
+
 class ModelInstancePair(object):
     def __init__(self, model, instance):
         self.model = model
         self.instance = instance
 
 
-# class Property(object):
-#     def __init__(self, model, obj):
-#         self.name = model
-#         self.obj = obj
+class Property(object):
+    def __init__(self, native_obj, name, model, accessor=None):
+        self.model = model
+        self.obj = native_obj
+        self.name = name
+        self.accessor = accessor if accessor else name
 
-# class FieldProperty(Property):
-#     pass
 
-# class RelationProperty(Property):
-#     def __init__(self, model, obj, cardinality):
-#         super(RelationProperty, self).__init__(model, obj)
-#         self.cardinality = cardinality
+class FieldProperty(Property):
+    pass
 
-# class OneToOneRelation(RelationProperty):
-#     pass
+
+class RelationProperty(Property):
+    # FIXME: 
+    # Some of the values on this class are relative to the model that returns it:
+    #   - accessor
+    #   - foreign_key_field
+    #   - cardinality
+    # Others are absolute, i.e. based on the model that "owns" the relationship:
+    #   - model
+    #   - related_model
+    # Consider making these last two relative. i.e. `model` should always be 
+    # the model to which `accessor` and `cardinality` apply, even if
+    # `related_model` "owns" the relationship 
+    # 
+    # Also, consider adding a boolean to indicate whether `model` "owns" the
+    # relationship (potential names: `is_owner`, `is_direct`, `model_is_owner`)
+    def __init__(self, native_obj, name, cardinality, model, related_model,
+                 association_model=None, foreign_key_field=None, accessor=None):
+        super(RelationProperty, self).__init__(native_obj, name, model, accessor)
+        self.related_model = related_model
+        self.association_model = association_model
+        self.cardinality = cardinality
+        self.foreign_key_field = foreign_key_field
+
+    def __repr__(self):
+        return '%s(%r, %r, %r)' % (self.__class__.__name__, self.obj, self.name,
+                                   self.cardinality)
+
 
 class ModelInspector(object):
     _inspectors = []
+    collection_listener = None
 
     def __init__(self, model):
         self._model = model
     def register(cls):
         cls._inspectors.append(cls)
 
+    def relation_map(self, by_accessor=True):
+        if by_accessor:
+            return dict((rel.accessor, rel) for rel in self.iter_relations())
+        else:
+            return dict((rel.name, rel) for rel in self.iter_relations())
     # TODO: move over base methods from DjangoModelInspector when locked down
 
+    def __eq__(self, other):
+        return self._model == other._model
+
+
+class CollectionListener(object):
+    def __init__(self, backend, collection, model, filter_path=None,
+                 relation=None):
+        self.backend = backend
+        self.collection = collection
+        self.model = model
+        self.filter_path = filter_path
+        self.relation = relation
+
+    def connect(self):
+        raise NotImplementedError()
+
+    def disconnect(self):
+        raise NotImplementedError()
+
+    def _get_affected_instance_ids(self, instance):
+        """"Find all affected root model instance ids.
+        These are the objects that reference the changed instance
+        """
+        if self.filter_path:
+            # FIXME: taking collection.queryset out of the equation for now
+            # since it is not cross-orm compatible
+            # Submodel
+            affected = self.collection.model.get_objects_affected_by(self.filter_path, instance)
+        else:
+            # FIXME: use inspector to get primary key
+            # This is the root model
+            affected = [instance.id]
+        return set(affected)
+
+    def _set_affected(self, ns, instance, affected_set):
+        """Used in pre_* handlers. Annotates an instance with
+        the ids of the root objects it affects, so that we can include
+        these in the final updates in the post_* handlers.
+        """
+        attname = '_denormalize_extra_affected_{0}_{1}'.format(self.collection.name, ns)
+        current = getattr(instance, attname, set())
+        new = current | affected_set
+        setattr(instance, attname, new)
+
+    def _get_affected(self, ns, instance):
+        """See _set_affected for an explanation."""
+        attname = '_denormalize_extra_affected_{0}_{1}'.format(self.collection.name, ns)
+        affected_set = getattr(instance, attname, set())
+        return affected_set
+
+    # -- callbacks
+
+    def pre_update(self, instance):
+        log.debug('pre_update: collection %s, %s (%s) with id %s',
+                  self.collection.name, self.model.__name__,
+                  self.filter_path or '^', instance.id)
+
+        if self.filter_path:
+            affected = self._get_affected_instance_ids(instance)
+            log.debug("pre_update:   affected documents: %r", affected)
+
+            # Instead of calling the update functions directly, we make
+            # sure it is queued for after the operation. Otherwise
+            # the changes are not reflected in the database yet.
+            self._set_affected('update', instance, affected)
+
+    def post_update(self, instance):
+        log.debug('post_update: collection %s, %s (%s) with id %s',
+                  self.collection.name, self.model.__name__,
+                  self.filter_path or '^', instance.id)
+
+        affected = self._get_affected_instance_ids(instance)
+        log.debug("post_update:   affected documents: %r", affected)
+
+        # This can be passed by the pre_save handler
+        extra_affected = self._get_affected('update', instance)
+        if extra_affected:
+            affected.update(extra_affected)
+
+        for doc_id in self.collection.map_affected(affected):
+            self.backend._queue_changed(self.collection, doc_id)
+
+    def pre_insert(self, instance):
+        # Not used, does not need to be called/connected
+        pass
+
+    def post_insert(self, instance):
+        log.debug('post_insert: collection %s, %s (%s) with id %s',
+                  self.collection.name, self.model.__name__,
+                  self.filter_path or '^', instance.id)
+
+        affected = self._get_affected_instance_ids(instance)
+        for doc_id in self.collection.map_affected(affected):
+            self.backend._queue_added(self.collection, doc_id)
+
+    def pre_delete(self, instance):
+        log.debug('pre_delete: collection %s, %s (%s) with id %s being deleted',
+                  self.collection.name, self.model.__name__,
+                  self.filter_path or '^', instance.id)
+
+        affected = self._get_affected_instance_ids(instance)
+        log.debug("pre_delete:   affected documents: %r", affected)
+
+        # Instead of calling the update functions directly, we make
+        # sure it is queued for after the operation. Otherwise
+        # the changes are not reflected in the database yet.
+        self._set_affected('delete', instance, affected)
+
+    def post_delete(self, instance):
+        log.debug('post_delete: collection %s, %s (%s) with id %s was deleted',
+                  self.collection.name, self.model.__name__,
+                  self.filter_path or '^', instance.id)
+
+        affected = self._get_affected('delete', instance)
+        if self.filter_path:
+            queue = self.backend._queue_changed
+        else:
+            queue = self.backend._queue_deleted
+        for doc_id in self.collection.map_affected(affected):
+            queue(self.collection, doc_id)
+
+
 def inspector(model):
 
     if isinstance(model, ModelInspector):
     for insp in ModelInspector._inspectors:
         if insp.is_compatible_class(model):
             return insp(model)
+

denormalize/orms/django.py

 from __future__ import absolute_import
-from .base import Skip, ModelInspector, ModelInstancePair
+from .base import Skip, ModelInspector, ModelInstancePair, RelationProperty, \
+    FieldProperty, CollectionListener
 from django.db import models
 from django.db.models import signals
 
     def is_compatible_class(model_class):
         return issubclass(model_class, models.Model)
 
+    # FIXME: move to CollectionModel?
     def get_native_object(self, primary_key):
         return self._model.objects.get(pk=primary_key)
 
-    def get_field_info(self, accessor):
+    # FIXME: move to CollectionModel?
+    def get_native_objects(self):
+        return self._model.objects.all()
+
+    # FIXME: move to CollectionModel?
+    # FIXME: if it does not affect performance, return instances here, and get
+    # the ids from them in the calling function
+    def get_objects_affected_by(self, join_path, instance):
+        filt = {join_path: instance}
+        # affected = self.collection.queryset(prefetch=False).filter(
+        #     **filt).values_list('id', flat=True)
+
+        return self._model.objects.all().filter(
+            **filt).values_list('id', flat=True)
+
+    def get_relation_models(self, accessor):
         if accessor in self._field_info_cache:
             return self._field_info_cache[accessor]
 
         # Most of the time these are equal
         fieldname = accessor
-        info = {}
 
         # WORKAROUND: Sometimes for unknown reasons the meta
         # information has not been fully updated. We check for
 
         # Get field info
         # TODO: show list of valid names if the field is invalid
-        try:
-            field, field_model, direct, m2m = \
-                self._model._meta.get_field_by_name(fieldname)
-        except models.FieldDoesNotExist:
-            if accessor.endswith('_set'):
-                # If no related_name is specified, the field name will
-                # be 'foo', while the queryset is accessible through
-                # 'foo_set'. If the related_name has been set, these
-                # will be the same.
-                fieldname = accessor[:-4] # strip '_set'
-                field, field_model, direct, m2m = \
-                    self._model._meta.get_field_by_name(fieldname)
-            else:
-                raise
+        # try:
+        field, field_model, is_owner, m2m = \
+            self._model._meta.get_field_by_name(fieldname)
+        # except models.FieldDoesNotExist:
+        #     if accessor.endswith('_set'):
+        #         # If no related_name is specified, the field name will
+        #         # be 'foo', while the queryset is accessible through
+        #         # 'foo_set'. If the related_name has been set, these
+        #         # will be the same.
+        #         fieldname = accessor[:-4] # strip '_set'
+        #         field, field_model, is_owner, m2m = \
+        #             self._model._meta.get_field_by_name(fieldname)
+        #     else:
+        #         raise
 
-        # We want to get the next model in our iteration
-        if direct and not m2m:
+        if is_owner and not m2m:
             # (<related.ForeignKey: publisher>, None, True, False)
-            model = field.related.parent_model
-        elif direct and m2m:
+            model = self
+            related_model = self.__class__(field.related.parent_model)
+            association_model = None
+        elif is_owner and m2m:
             # (<related.ManyToManyField: authors>, None, True, True)
             # (<related.ManyToManyField: tags>, None, True, True)
-            model = field.rel.to
-            info['through'] = self.__class__(field.rel.through)
-        elif not direct and not m2m:
+            model = self
+            related_model = self.__class__(field.rel.to)
+            association_model = field.rel.through
+        elif not is_owner and not m2m:
             # (<RelatedObject: tests:extrabookinfo related to book>, None, False, False)
             # (<RelatedObject: tests:chapter related to book>, None, False, False)
             # (<RelatedObject: tests:publisherlink related to publisher>, None, False, False)
-            model = field.model
-        elif not direct and m2m:
+            model = self.__class__(field.model)
+            related_model = self
+            association_model = None
+        elif not is_owner and m2m:
             # (<RelatedObject: tests:category related to books>, None, False, True)
-            model = field.model
-            info['through'] = self.__class__(field.field.rel.through)
+            model = self.__class__(field.model)
+            related_model = self
+            association_model = field.field.rel.through
         else:
             # Cannot happen, we covered all four cases
             raise AssertionError("Impossible")
 
-        info['model'] = self.__class__(model)
-        info['m2m'] = m2m
-        info['direct'] = direct
-        self._field_info_cache[accessor] = fieldname, info
-        return fieldname, info
+        if association_model is not None:
+            association_model = self.__class__(association_model)
+        self._field_info_cache[accessor] = (model, related_model, association_model)
+        return model, related_model, association_model
 
     # -- fields
 
     #                                                      source):
     #         yield relname, field_id
 
-    def _iter_relations(self, obj, cardinality):
+    def iter_relations(self, obj=None, cardinality='any'):
         assert cardinality in ('has one', 'has many', 'any')
 
         if obj is not None:
         else:
             meta = self._model._meta
 
-        # FIXME: look into combining these loops
+        # FIXME: cache all of this
+        related = meta.get_all_related_objects_with_model()
         if cardinality in ('has one', 'any'):
-            for rel, model in meta.get_all_related_objects_with_model():
+            for rel, model in related:
                 if isinstance(rel.field, models.OneToOneField):
-                    yield rel.get_accessor_name(), None, rel
+                    accessor = rel.get_accessor_name()
+                    name = rel.field.related_query_name()
+                    yield RelationProperty(
+                        rel, name, 'has one',
+                        *self.get_relation_models(accessor),
+                        accessor=accessor)
 
             for rel, field_id, field in self._iter_foreign_keys(obj):
-                yield rel, field_id, field
+                yield RelationProperty(
+                    field, rel, 'has one',
+                    *self.get_relation_models(rel),
+                    foreign_key_field=field_id)
 
         if cardinality in ('has many', 'any'):
-            for rel, model in meta.get_all_related_objects_with_model():
+            for rel, model in related:
                 if not isinstance(rel.field, models.OneToOneField):
-                    yield rel.get_accessor_name(), None, rel
+                    accessor = rel.get_accessor_name()
+                    name = rel.field.related_query_name()
+                    yield RelationProperty(
+                        rel, name, 'has many',
+                        *self.get_relation_models(name),
+                        accessor=accessor)
 
             for rel, model in meta.get_all_related_m2m_objects_with_model():
-                yield rel.get_accessor_name(), None, rel
+                accessor = rel.get_accessor_name()
+                name = rel.field.related_query_name()
+                yield RelationProperty(
+                    rel, name, 'has many',
+                    *self.get_relation_models(name),
+                    accessor=accessor)
 
             for rel, model in meta.get_m2m_with_model():
-                yield rel.name, None, rel
-
-    def iter_relations(self, obj, cardinality='any'):
-        for relname, field_id, _ in self._iter_relations(obj, cardinality):
-            yield relname, field_id
+                yield RelationProperty(
+                    rel, rel.name, 'has many',
+                    *self.get_relation_models(rel.name))
 
     def get_value(self, obj, accessor):
         try:
         else:
             return value
 
-class DjangoConnector(object):
-    def __init__(self):
-        # Keep references to listeners to prevent garbage collection
-        self._listeners = []
 
-    @staticmethod
-    def _get_affected_instances(filter_path, collection, instance):
-        # Find all affected root model instance ids
-        # These are the objects that reference the changed instance
-        if filter_path:
-            # Submodel
-            filt = {filter_path: instance}
-            affected = set(collection.queryset(prefetch=False).filter(
+class DjangoCollectionListener(CollectionListener):
+    def connect(self):
+        # Connect to the save signal
+        signals.pre_save.connect(self.pre_save, sender=self.model)
+        signals.post_save.connect(self.post_save, sender=self.model)
+        signals.pre_delete.connect(self.pre_delete, sender=self.model)
+        signals.post_delete.connect(self.post_delete, sender=self.model)
+        # M2M handling
+        if self.relation and self.relation.association_model:
+            through_model = self.relation.association_model._model
+            signals.m2m_changed.connect(self.m2m_changed,
+                                        sender=through_model)
+
+    # -- callbacks
+
+    def pre_save(self, sender, instance, raw, **kwargs):
+        # Used to detect FK changes
+        created_guess = not instance.pk
+
+        # From the Django docs: "True if the model is saved exactly as
+        # presented (i.e. when loading a fixture). One should not
+        # query/modify other records in the database as the database
+        # might not be in a consistent state yet."
+        if raw:
+            log.warn("pre_save: raw=True, so no document sync performed!")
+            return
+
+        # We only need this to monitor FK *changes*, so no need to do
+        # anything for new objects.
+        if created_guess:
+            log.debug("pre_save:   looks new, no action needed")
+            # not needed:
+            # self.post_insert()
+            return
+
+        self.pre_update(instance)
+
+    def post_save(self, sender, instance, created, raw, **kwargs):
+        # From the Django docs: "True if the model is saved exactly as
+        # presented (i.e. when loading a fixture). One should not
+        # query/modify other records in the database as the database
+        # might not be in a consistent state yet."
+        if raw:
+            log.warn("post_save: raw=True, so no document sync performed!")
+            return
+
+        if created:
+            self.post_insert(instance)
+        else:
+            self.post_update(instance)
+
+    def pre_delete(self, sender, instance, **kwargs):
+        super(DjangoCollectionListener, self).pre_delete(instance)
+
+    def post_delete(self, sender, instance, **kwargs):
+        super(DjangoCollectionListener, self).post_delete(instance)
+
+    def m2m_changed(self, sender, instance, action, reverse, model, pk_set, **kwargs):
+        """
+        :param sender: The intermediate model class describing the
+            ManyToManyField. This class is automatically created when a
+            many-to-many field is defined; you can access it using the through
+            attribute on the many-to-many field.
+        :param instance: The instance whose many-to-many relation is updated.
+            This can be an instance of the sender, or of the class the
+            ManyToManyField is related to.
+        :param action: A string indicating the type of update that is done on
+            the relation.
+        :param reverse: Indicates which side of the relation is updated (i.e.,
+            if it is the forward or reverse relation that is being modified).
+        :param model: The class of the objects that are added to, removed from
+            or cleared from the relation.
+            e.g. `Book` in `category.books.add(book)`
+        :param pk_set: For the pre_add, post_add, pre_remove and post_remove
+            actions, this is a set of primary key values that have been added
+            to or removed from the relation. For the pre_clear and post_clear
+            actions, this is None.
+
+        For example, given::
+
+            category.books.add(book1)
+            category.books.add(book2)
+
+        The function arguments would be:
+            - instance: category
+            - model: Book
+            - pk_set: set([book1.pk, book2.pk])
+            - reverse: False
+
+        Given::
+
+            book1.category_set.add(category)
+
+        The function arguments would be:
+            - instance: book1
+            - model: Category
+            - pk_set: set([category.pk])
+            - reverse: True?
+        """
+        # An m2m change affects both sides: if an Author is added to a Book,
+        # the Author will also have a new book on its side. It does not
+        # matter however in which direction we are looking at this relation,
+        # since any check on the instance passed will lead us to the root
+        # object that must be marked as changed.
+
+        # Opposite side of `model`
+        instance_model = instance.__class__
+
+        # FIXME: debug level
+        log.debug('m2m_changed: collection %s, %s (%s) with %s.id=%s m2m '
+            '%s on %s side (reverse=%s), pk_set=%s',
+            self.collection.name, self.model.__name__, self.filter_path or '^',
+            instance_model.__name__,
+            instance.id, action, model.__name__, reverse, pk_set)
+
+        # Optimization: if either side is our root object, send out changes
+        # without querying.
+        if instance_model is self.collection.model._model:
+            self.backend._queue_changed(self.collection, instance.id)
+            return
+
+        elif model is self.collection.model._model and pk_set:
+            for pk in pk_set:
+                self.backend._queue_changed(self.collection, pk)
+            return
+
+        # Otherwise figure out which side is equal to the model we are
+        # registering handlers for, and use that for querying.
+        if instance_model is self.model:
+            affected = self._get_affected_instance_ids(instance)
+            for doc_id in self.collection.map_affected(affected):
+                self.backend._queue_changed(self.collection, doc_id)
+            return
+
+        elif model is self.model and pk_set:
+            # FIXME: this is wrong! setting tags does not affect all other
+            #        book that have the same tag!
+            filt = {'{0}__pk__in'.format(self.filter_path): pk_set}
+            affected = set(self.collection.queryset(prefetch=False).filter(
                 **filt).values_list('id', flat=True))
-        else:
-            # This is the root model
-            affected = set([instance.id])
-        return affected
+            for doc_id in self.collection.map_affected(affected):
+                self.backend._queue_changed(self.collection, doc_id)
+            return
 
-    def add_listeners(self, backend, collection, filter_path, submodel, info):
-        """Connect the Django ORM signals to given dependency
+        # FIXME: how to handle the clear signals (pre and post) in this form?
+        # Tag (publisher__tags) with Publisher.id=1 m2m pre_clear on Tag side (reverse=False), pk_set=None
+        # We don't have any ID for Tag, only for the other side
+        # --> strip last item from filter path
 
-        :type collection: denormalize.models.DocumentCollection
-        :param filter_path: ORM filter path (not always the same as the
-            path used in *_related! For example, 'chapter' instead of
-            'chapter_set'!)
-        :type filter_path: basestring, None
-        :param submodel: dependency model to watch for changes
-        :type submodel: django.db.models.Model
-        :param info: dict as returned by `DocumentCollection.get_model_info`
-        :type info: dict, None
-        """
-        # TODO: this does not handle changing foreignkeys well. The object
-        #       will be added to the new root, but not deleted from the old
-        #       root object. Maybe solve by also adding a pre_save? The
-        #       database will still contain the old connection.
-        # TODO: Consider moving all of this to a Collection or something
-        #       separate and just listen to signals from there.
+        # The type of action does not matter, we simply do a lookup for
+        # the given model and queue the affected documents.
+        #affected = getattr(instance, '_denormalize_m2m_affected', set())
+        #if filter_path:
+        #    for doc_id in collection.map_affected(affected):
+        #        backend._queue_changed(collection, doc_id)
+        #else:
+        #    for doc_id in collection.map_affected(set([instance.id])):
+        #        backend._queue_deleted(collection, doc_id)
 
-        submodel = submodel._model
-
-        def pre_save(sender, instance, raw, **kwargs):
-            # Used to detect FK changes
-            created_guess = not instance.pk
-            log.debug('pre_save: collection %s, %s (%s) with id %s %s',
-                collection.name, submodel.__name__, filter_path or '^',
-                instance.id, 'added' if created_guess else 'changed')
-
-            # From the Django docs: "True if the model is saved exactly as
-            # presented (i.e. when loading a fixture). One should not
-            # query/modify other records in the database as the database
-            # might not be in a consistent state yet."
-            if raw:
-                log.warn("pre_save: raw=True, so no document sync performed!")
-                return
-
-            # We only need this to monitor FK *changes*, so no need to do
-            # anything for new objects.
-            if created_guess:
-                log.debug("pre_save:   looks new, no action needed")
-                return
-
-            affected = self._get_affected_instances(filter_path, collection,
-                                                    instance)
-            if filter_path:
-                log.debug("pre_save:   affected documents: %r", affected)
-
-                # Instead of calling the update functions directly, we make
-                # sure it is queued for after the operation. Otherwise
-                # the changes are not reflected in the database yet.
-                backend._set_affected('save', collection, instance, affected)
-
-
-        def post_save(sender, instance, created, raw, **kwargs):
-            log.debug('post_save: collection %s, %s (%s) with id %s %s',
-                collection.name, submodel.__name__, filter_path or '^',
-                instance.id, 'created' if created else 'changed')
-
-            # From the Django docs: "True if the model is saved exactly as
-            # presented (i.e. when loading a fixture). One should not
-            # query/modify other records in the database as the database
-            # might not be in a consistent state yet."
-            if raw:
-                log.warn("post_save: raw=True, so no document sync performed!")
-                return
-
-            affected = self._get_affected_instances(filter_path, collection,
-                                                    instance)
-            log.debug("post_save:   affected documents: %r", affected)
-
-            # This can be passed by the pre_save handler
-            extra_affected = backend._get_affected('save', collection, instance)
-            if extra_affected:
-                affected |= extra_affected
-
-            if filter_path or not created:
-                for doc_id in collection.map_affected(affected):
-                    backend._queue_changed(collection, doc_id)
-            else:
-                for doc_id in collection.map_affected(set([instance.id])):
-                    backend._queue_added(collection, doc_id)
-
-
-        def pre_delete(sender, instance, **kwargs):
-            log.debug('pre_delete: collection %s, %s (%s) with id %s being deleted',
-                collection.name, submodel.__name__, filter_path or '^',
-                instance.id)
-
-            affected = self._get_affected_instances(filter_path, collection,
-                                                    instance)
-            log.debug("pre_delete:   affected documents: %r", affected)
-
-            # Instead of calling the update functions directly, we make
-            # sure it is queued for after the operation. Otherwise
-            # the changes are not reflected in the database yet.
-            backend._set_affected('delete', collection, instance, affected)
-
-        def post_delete(sender, instance, **kwargs):
-            log.debug('post_delete: collection %s, %s (%s) with id %s was deleted',
-                collection.name, submodel.__name__, filter_path or '^',
-                instance.id)
-            affected = backend._get_affected('delete', collection, instance)
-            if filter_path:
-                for doc_id in collection.map_affected(affected):
-                    backend._queue_changed(collection, doc_id)
-            else:
-                for doc_id in collection.map_affected(set([instance.id])):
-                    backend._queue_deleted(collection, doc_id)
-
-
-        def m2m_changed(sender, instance, action, reverse, model, pk_set, **kwargs):
-            # An m2m change affects both sides: if an Author is added to a Book,
-            # the Author will also have a new book on its side. It does not
-            # matter however in which direction we are looking at this relation,
-            # since any check on the instance passed will lead us to the root
-            # object that must be marked as changed.
-
-            # Opposite side of `model`
-            instance_model = instance.__class__
-
-            # FIXME: debug level
-            log.info('m2m_changed: collection %s, %s (%s) with %s.id=%s m2m '
-                '%s on %s side (reverse=%s), pk_set=%s',
-                collection.name, submodel.__name__, filter_path or '^',
-                instance_model.__name__,
-                instance.id, action, model.__name__, reverse, pk_set)
-
-            # Optimization: if either side is our root object, send out changes
-            # without querying.
-            if instance_model is collection.model._model:
-                backend._queue_changed(collection, instance.id)
-                return
-
-            elif model is collection.model._model and pk_set:
-                for pk in pk_set:
-                    backend._queue_changed(collection, pk)
-                return
-
-            # Otherwise figure out which side is equal to the model we are
-            # registering handlers for, and use that for querying.
-            if instance_model is submodel:
-                filt = {filter_path: instance}
-                affected = set(collection.queryset(prefetch=False).filter(
-                    **filt).values_list('id', flat=True))
-                for doc_id in collection.map_affected(affected):
-                    backend._queue_changed(collection, doc_id)
-                return
-
-            elif model is submodel and pk_set:
-                # FIXME: this is wrong! setting tags does not affect all other
-                #        book that have the same tag!
-                filt = {'{0}__pk__in'.format(filter_path): pk_set}
-                affected = set(collection.queryset(prefetch=False).filter(
-                    **filt).values_list('id', flat=True))
-                for doc_id in collection.map_affected(affected):
-                    backend._queue_changed(collection, doc_id)
-                return
-
-            # FIXME: how to handle the clear signals (pre and post) in this form?
-            # Tag (publisher__tags) with Publisher.id=1 m2m pre_clear on Tag side (reverse=False), pk_set=None
-            # We don't have any ID for Tag, only for the other side
-            # --> strip last item from filter path
-
-            # The type of action does not matter, we simply do a lookup for
-            # given model and
-            #affected = getattr(instance, '_denormalize_m2m_affected', set())
-            #if filter_path:
-            #    for doc_id in collection.map_affected(affected):
-            #        backend._queue_changed(collection, doc_id)
-            #else:
-            #    for doc_id in collection.map_affected(set([instance.id])):
-            #        backend._queue_deleted(collection, doc_id)
-
-
-        # We need to keep a reference, because signal connections are weak
-        self._listeners.append(pre_save)
-        self._listeners.append(post_save)
-        self._listeners.append(pre_delete)
-        self._listeners.append(post_delete)
-        # Connect to the save signal
-        signals.pre_save.connect(pre_save, sender=submodel)
-        signals.post_save.connect(post_save, sender=submodel)
-        signals.pre_delete.connect(pre_delete, sender=submodel)
-        signals.post_delete.connect(post_delete, sender=submodel)
-        # M2M handling
-        if info and info['m2m']:
-            backend._listeners.append(m2m_changed)
-            through_model = info['through']._model
-            signals.m2m_changed.connect(m2m_changed, sender=through_model)
-
+DjangoModelInspector.listener = DjangoCollectionListener
 DjangoModelInspector.register()

denormalize/orms/sqlalchemy.py

 from sqlalchemy import types as sqltypes
 from sqlalchemy.orm.exc import UnmappedClassError
 from sqlalchemy.orm.collections import InstrumentedList
+from sqlalchemy import event
 
-from .base import Skip, ModelInspector, ModelInstancePair
+from .base import Skip, ModelInspector, ModelInstancePair, RelationProperty, \
+    FieldProperty, CollectionListener
 import logging
 
 log = logging.getLogger(__name__)
     except UnmappedClassError:
         return False
 
+
 class SqlAlchemyModelInspector(ModelInspector):
+    # FIXME: figure out how to get the session to the inspector:
+    _session = None
+
     def __init__(self, model):
         super(SqlAlchemyModelInspector, self).__init__(model)
         self._mapper = class_mapper(model)
-        self._field_info_cache = None
-        self._relation_info_cache = None
+        self._fieldname_to_relation = None
+        self._relation_name_to_field_name = None
+        self._field_info_cache = {}
 
     @property
     def table_name(self):
     def is_compatible_class(model_class):
         return _is_model_class(model_class)
 
+    # FIXME: move to CollectionModel?
     def get_native_object(self, primary_key):
-        return self._model.objects.get(pk=primary_key)
+        q = self._session.query(self._model).filter_by(id=primary_key)
+        return q.one()
 
-    def get_field_info(self, accessor):
-        if accessor in self._field_info_cache:
-            return self._field_info_cache[accessor]
+    # FIXME: move to CollectionModel?
+    def get_native_objects(self):
+        return self._session.query(self._model).all()
 
-        # Most of the time these are equal
-        fieldname = accessor
-        info = {}
+    # FIXME: move to CollectionModel?
+    # FIXME: if it does not affect performance, return instances here, and get
+    # the ids from them in the calling function
+    def get_objects_affected_by(self, join_path, instance):
+        # FIXME: look into using the session from the instance:
+        # session = sqlalchemy.inspect(instance).session
+        join_paths = join_path.split('__')
+        q = self._session.query(self._model.id).join(*join_paths)
+        filt = {'id': instance.id}
+        return [x[0] for x in q.filter_by(**filt).all()]
 
-        # Get field info
-        # TODO: show list of valid names if the field is invalid
-        try:
-            field, field_model, direct, m2m = \
-                self._model._meta.get_field_by_name(fieldname)
-        except models.FieldDoesNotExist:
-            if accessor.endswith('_set'):
-                # If no related_name is specified, the field name will
-                # be 'foo', while the queryset is accessible through
-                # 'foo_set'. If the related_name has been set, these
-                # will be the same.
-                fieldname = accessor[:-4] # strip '_set'
-                field, field_model, direct, m2m = \
-                    self._model._meta.get_field_by_name(fieldname)
-            else:
-                raise
+    def get_relation_models(self, rel):
+        if rel.key in self._field_info_cache:
+            return self._field_info_cache[rel.key]
 
-        # We want to get the next model in our iteration
-        if direct and not m2m:
-            # (<related.ForeignKey: publisher>, None, True, False)
-            model = field.related.parent_model
-        elif direct and m2m:
-            # (<related.ManyToManyField: authors>, None, True, True)
-            # (<related.ManyToManyField: tags>, None, True, True)
-            model = field.rel.to
-            info['through'] = self.__class__(field.rel.through)
-        elif not direct and not m2m:
-            # (<RelatedObject: tests:extrabookinfo related to book>, None, False, False)
-            # (<RelatedObject: tests:chapter related to book>, None, False, False)
-            # (<RelatedObject: tests:publisherlink related to publisher>, None, False, False)
-            model = field.model
-        elif not direct and m2m:
-            # (<RelatedObject: tests:category related to books>, None, False, True)
-            model = field.model
-            info['through'] = self.__class__(field.field.rel.through)
+        # print rel, type(rel)
+        # print "backref", repr(rel.backref)
+        # # print rel.remote_side
+        # # print rel.direction
+        # print rel.local_columns
+        # print "model_table  ", self._model.__table__
+        # print "relate_table ", rel.table
+        # print "relate_target", rel.target
+        # print [x for x in dir(rel) if not x.startswith('_')]
+        # # print [x for x in dir(rel.mapper) if not x.startswith('_')]
+        # print
+        # m2m = rel.secondary is not None
+
+        if rel.backref is None:
+            model = self.__class__(rel.mapper.class_)
+            related_model = self
         else:
-            # Cannot happen, we covered all four cases
-            raise AssertionError("Impossible")
-
-        info['model'] = self.__class__(model)
-        info['m2m'] = m2m
-        info['direct'] = direct
-        self._field_info_cache[accessor] = fieldname, info
-        return fieldname, info
+            related_model = self.__class__(rel.mapper.class_)
+            model = self
+        # # We want to get the next model in our iteration
+        # if direct and not m2m:
+        #     # (<related.ForeignKey: publisher>, None, True, False)
+        #     model = field.related.parent_model
+        # elif direct and m2m:
+        #     # (<related.ManyToManyField: authors>, None, True, True)
+        #     # (<related.ManyToManyField: tags>, None, True, True)
+        #     model = field.rel.to
+        #     info['through'] = self.__class__(field.rel.through)
+        # elif not direct and not m2m:
+        #     # (<RelatedObject: tests:extrabookinfo related to book>, None, False, False)
+        #     # (<RelatedObject: tests:chapter related to book>, None, False, False)
+        #     # (<RelatedObject: tests:publisherlink related to publisher>, None, False, False)
+        #     model = field.model
+        # elif not direct and m2m:
+        #     # (<RelatedObject: tests:category related to books>, None, False, True)
+        #     model = field.model
+        #     info['through'] = self.__class__(field.field.rel.through)
+        # else:
+        #     # Cannot happen, we covered all four cases
+        #     raise AssertionError("Impossible")
+        association_model = rel.secondary
+        self._field_info_cache[rel.key] = (model, related_model, association_model)
+        return model, related_model, association_model
 
     # -- fields
 
         return
 
     def _fill_field_info(self):
-        if self._field_info_cache is not None:
+        if self._fieldname_to_relation is not None:
             return
-        self._field_info_cache = {}
-        self._relation_info_cache = {}
+        self._fieldname_to_relation = {}
+        self._relation_name_to_field_name = {}
         for name, field in self._iter_fields():
             relation = self._get_basic_relation(field)
             if relation is not None:
-                self._field_info_cache[name] = relation
-                self._relation_info_cache[relation.key] = name
+                self._fieldname_to_relation[name] = relation
+                self._relation_name_to_field_name[relation.key] = name
 
     def _iter_fields(self, obj=None):
         for field in self._mapper.column_attrs:
         self._fill_field_info()
 
         for name, field in self._iter_fields(obj):
-            if name not in self._field_info_cache:
+            if name not in self._fieldname_to_relation:
                 yield name
 
     # -- relationships
 
-    def _iter_relations(self, obj, cardinality):
+    def iter_relations(self, obj=None, cardinality='any'):
         assert cardinality in ('has one', 'has many', 'any')
 
         self._fill_field_info()
         for rel in self._mapper.relationships:
             cardinal = "has many" if rel.uselist else "has one"
             if cardinality in ('any', cardinal):
-                fk_field = self._relation_info_cache.get(rel.key)
-                yield rel.key, fk_field, rel
-
-    def iter_relations(self, obj, cardinality='any'):
-        for relname, field_id, _ in self._iter_relations(obj, cardinality):
-            yield relname, field_id
+                fk_field = self._relation_name_to_field_name.get(rel.key)
+                yield RelationProperty(
+                    rel, rel.key, cardinal,
+                    *self.get_relation_models(rel),
+                    foreign_key_field=fk_field)
 
     def get_value(self, obj, accessor):
         value = getattr(obj, accessor)
                 result.append(ModelInstancePair(model, res))
             return result
         else:
+            # FIXME: improve this
+            if value is None and accessor in self.relation_map():
+                raise Skip
             return value
 
+class SqlAlchemyCollectionListener(CollectionListener):
+    def connect(self):
+        # Connect to the save signal
+        event.listen(self.model, 'before_update', self.before_update)
+        event.listen(self.model, 'after_update', self.after_update)
+        event.listen(self.model, 'after_insert', self.after_insert)
+        event.listen(self.model, 'before_delete', self.before_delete)
+        event.listen(self.model, 'after_delete', self.after_delete)
+
+        # M2M handling
+        # if self.relation and self.relation.association_model is not None:
+        #     event.listen(self.relation.obj, 'append', self.m2m_changed)
+
+    # -- callbacks
+
+    def before_update(self, mapper, connection, instance):
+        self.pre_update(instance)
+
+    def after_update(self, mapper, connection, instance):
+        self.post_update(instance)
+
+    def after_insert(self, mapper, connection, instance):
+        self.post_insert(instance)
+
+    def before_delete(self, mapper, connection, instance):
+        self.pre_delete(instance)
+
+    def after_delete(self, mapper, connection, instance):
+        self.post_delete(instance)
+
+    def m2m_changed(self, mapper, connection, instance):
+        print "sql m2m", mapper, instance
+
+
+SqlAlchemyModelInspector.listener = SqlAlchemyCollectionListener
 SqlAlchemyModelInspector.register()

denormalize/tests/backend.py

 from pprint import pprint
 import os
-
-from django.conf import settings
+from functools import partial
 
 from ..models import *
 from ..backend.locmem import LocMemBackend
 from ..context import delay_sync, sync_together
 
+import logging
+log = logging.getLogger(__name__)
+
+# TODO: add rollback test
 
 class BackendTestMixin(object):
     SUPPORTS_SYNC_COLLECTION = True
 
     def _create_backend(self):
-        return LocMemBackend()
+        raise NotImplementedError()
 
     def test_dump(self):
         bookcol = self.collection()
         self.assertTrue('tags' in doc, doc)
 
         # Change a one to many link
+        log.info("update book chapter (one-to-many)")
         self.update_chapter()
         doc = backend.get_doc(bookcol, 1)
         chapter = doc['chapter_set'][0]
         self.assertTrue(chapter['id'] == 1)
-        self.assertTrue(chapter['title'].endswith('!!!'), doc)
+        self.assertTrue(chapter['title'].endswith('!!!'), chapter)
 
         # Change something (m2m)
+        log.info("update book author (many-to-many)")
         self.update_author()
         doc = backend.get_doc(bookcol, 1)
         self.assertEqual(doc['authors'][0]['name'], 'Another Name')
 
         # Change something that's shared (m2m)
+        log.info("update book/publisher tag (multiple many-to-many)")
         self.update_tag()
         doc = backend.get_doc(bookcol, 1)
         self.assertTrue('tech' in doc['tags'])
 
         # Add a chapter to a book
+        log.info("insert chapter (one-to-many)")
         book1, chapter = self.insert_chapter()
         doc = backend.get_doc(bookcol, 1)
         self.assertTrue("Conclusion" in (x['title'] for x in doc['chapter_set']))
 
         # Move a chapter to another book (FK change!)
+        log.info("move chapter")
         book2 = self.move_chapter(chapter)
         doc = backend.get_doc(bookcol, 1)
         self.assertFalse("Conclusion" in (x['title'] for x in doc['chapter_set']))
         self.assertTrue("Conclusion" in (x['title'] for x in doc['chapter_set']))
 
         # Delete the chapter
+        log.info("delete chapter (one-to-many)")
         self.delete_chapter(chapter)
         doc = backend.get_doc(bookcol, 2)
         self.assertFalse("Conclusion" in (x['title'] for x in doc['chapter_set']))
 
         # Add a tag (m2m updated!!!) FIXME: not supported yet
+        log.info("add tag to book (many-to-many)")
         tag = self.add_tag(book2)
 
         # Remove the tag from the book from the other side (reverse=True)
+        log.info("remove tag from book (reverse many-to-many)")
         self.remove_tag_from_reverse(tag, book2)
 
         if self.SUPPORTS_SYNC_COLLECTION:
             backend.sync_collection(bookcol)
 
             # Next, with dirty records
-            def inject_dirty():
-                newbook = self.models.Book.objects.create(
-                    title="Some title", publisher=book2.publisher)
-                self.assertTrue(newbook.id in backend._dirty['books'])
-                backend._dirty['books'].add(1)
+            inject_dirty = partial(self.inject_dirty, backend, book2)
             backend._sync_collection_before_handling_dirty = inject_dirty
             backend.sync_collection(bookcol)
 
             # during the full sync.
             self.assertEqual(len(backend.data['books']), 3)
 
+class LocMemBackendTestMixin(BackendTestMixin):
+    def _create_backend(self):
+        return LocMemBackend()
 
-# The MongoDB database to use for tests (required)
-TEST_MONGO_DB = getattr(settings, 'DENORMALIZE_TEST_MONGO_DB', None)
-# Optional, defaults to localhost
-TEST_MONGO_URI = getattr(settings, 'DENORMALIZE_TEST_MONGO_URI', None)
 
-if TEST_MONGO_DB:
-    from ..backend.mongodb import MongoBackend
+class MongoBackendMixin(BackendTestMixin):
+    SUPPORTS_SYNC_COLLECTION = False # FIXME
+    def _create_backend(self):
+        from ..backend.mongodb import MongoBackend
+        return MongoBackend(db_name=TEST_MONGO_DB,
+                            connection_uri=TEST_MONGO_URI)
 
-    class MongoBackendMixin(BackendTestMixin):
-        SUPPORTS_SYNC_COLLECTION = False # FIXME
-        def _create_backend(self):
-            return MongoBackend(db_name=TEST_MONGO_DB,
-                                connection_uri=TEST_MONGO_URI)
-
-else:
-    print "WARNING: skipping MongoDB backend test, because " \
-        "settings.DENORMALIZE_TEST_MONGO_DB is not set!"

denormalize/tests/collections.py

+collections.py

denormalize/tests/common.py

         # Other expectations
 
         # - Tags are pure strings
-        self.assertEqual(doc['tags'], [u'cooking', u'geeks', u'technology'])
+        self.assertItemsEqual(doc['tags'], [u'cooking', u'geeks', u'technology'])
 
         # - ForeignKeys not explicitly followed are included as a pure id
         chapter = doc['chapter_set'][0]
     def test_get_related_models(self):
         bookcol = self.collection()
         deps = bookcol.get_related_models()
-        for rel in bookcol.select_related + bookcol.prefetch_related:
-            # Difference between the filter name and the path name
-            if rel == 'chapter_set':
-                rel = 'chapter'
-            elif rel == 'category_set':
-                rel = 'category'
-            self.assertTrue(rel in deps, "Relation {0} not found!".format(rel))
-        pprint(deps)
-        self.assertIs(deps['publisher']['model']._model, self.models.Publisher)
-        self.assertIs(deps['extra_info']['model']._model, self.models.ExtraBookInfo)
-        self.assertIs(deps['chapter']['model']._model, self.models.Chapter)
-        self.assertIs(deps['authors']['model']._model, self.models.Author)
-        self.assertIs(deps['tags']['model']._model, self.models.Tag)
-        self.assertIs(deps['publisher__links']['model']._model, self.models.PublisherLink)
+        if 'pprint' in os.environ:
+            pprint(deps)
+        self.check_related_models(bookcol, deps)
+

denormalize/tests/test_django/common.py

 
     @classmethod
     def setUpClass(cls, *args, **kwargs):
-        # NOTE: this is not thread-safe
-        cls.collection.model = cls.models.Book
-
-        if not cls._test_models_initiated:
+        # we only want to create the tables once, not once per sub-class
+        if not ModelTestCase._test_models_initiated:
             cls.create_models_from_app(cls.test_models_app)
-            cls._test_models_initiated = True
+            ModelTestCase._test_models_initiated = True
         super(ModelTestCase, cls).setUpClass(*args, **kwargs)
 
     @classmethod

denormalize/tests/test_django/test_backends.py

-from ..backend import BackendTestMixin
+from ..backend import LocMemBackendTestMixin, MongoBackendMixin
 from .common import ModelTestCase
 from . import models
+from django.conf import settings
+from unittest import skipIf
 
-
-class BackendTest(BackendTestMixin, ModelTestCase):
+class LocMemBackendTest(LocMemBackendTestMixin, ModelTestCase):
     models = models
 
     def update_chapter(self):
         # 'clear' action
         tag.book_set = []
         book2.publisher.tags = [tag]
+
+    def inject_dirty(self, backend, book2):
+        newbook = self.models.Book.objects.create(
+            title="Some title", publisher=book2.publisher)
+        self.assertTrue(newbook.id in backend._dirty['books'])
+        backend._dirty['books'].add(1)
+
+# The MongoDB database to use for tests (required)
+TEST_MONGO_DB = getattr(settings, 'DENORMALIZE_TEST_MONGO_DB', None)
+# Optional, defaults to localhost
+TEST_MONGO_URI = getattr(settings, 'DENORMALIZE_TEST_MONGO_URI', None)
+
+@skipIf(not TEST_MONGO_DB, "settings.DENORMALIZE_TEST_MONGO_DB is not set!")
+class MongoBackendTest(MongoBackendMixin, ModelTestCase):
+    models = models

denormalize/tests/test_django/test_collections.py

 
 # CollectionTestMixin must come first so that its setUp gets called
 class DjangoCollectionTest(CollectionTestMixin, ModelTestCase):
-    pass
+    def iter_relationship_filter_names(self, bookcol):
+        for rel in bookcol.select_related + bookcol.prefetch_related:
+            # Difference between the filter name and the path name
+            if rel == 'chapter_set':
+                rel = 'chapter'
+            elif rel == 'category_set':
+                rel = 'category'
+            yield rel
+
+    def check_related_models(self, bookcol, deps):
+        for rel in bookcol.select_related + bookcol.prefetch_related:
+            # Difference between the filter name and the path name
+            if rel == 'chapter_set':
+                rel = 'chapter'
+            elif rel == 'category_set':
+                rel = 'category'
+            self.assertTrue(rel in deps, "Relation {0} not found!".format(rel))
+        self.assertIs(deps['publisher']['model']._model, self.models.Publisher)
+        self.assertIs(deps['extra_info']['model']._model, self.models.ExtraBookInfo)
+        self.assertIs(deps['chapter']['model']._model, self.models.Chapter)
+        self.assertIs(deps['authors']['model']._model, self.models.Author)
+        self.assertIs(deps['tags']['model']._model, self.models.Tag)
+        self.assertIs(deps['publisher__links']['model']._model, self.models.PublisherLink)
 
 class DjangoORMTest(ModelTestCase):
 

denormalize/tests/test_sqlalchemy/common.py

 class BookCollection(common.BookCollection):
     model = models.Book
 
+# FIXME:
+from denormalize.orms.sqlalchemy import SqlAlchemyModelInspector
+SqlAlchemyModelInspector._session = models.Session()
+
+
 class ModelTestCase(unittest2.TestCase):
     _test_models_initiated = False
     models = models
 
     @classmethod
     def setUpClass(cls, *args, **kwargs):
-        # NOTE: this is not thread-safe
-        cls.collection.model = cls.models.Book
-
-        if not cls._test_models_initiated:
-            cls.create_models()
-            cls._test_models_initiated = True
+        # we only want to create the tables once, not once per sub-class
+        if not ModelTestCase._test_models_initiated:
+            models.create_test_tables()
+            ModelTestCase._test_models_initiated = True
         super(ModelTestCase, cls).setUpClass(*args, **kwargs)
 
     @classmethod
-    def create_models(cls):
-        models.Base.metadata.create_all(models.engine)
+    def tearDown(cls, *args, **kwargs):
+        models.Session.remove()

denormalize/tests/test_sqlalchemy/models.py

     from sqlalchemy.ext.declarative.base import declarative_base # 0.9
 from sqlalchemy import create_engine, Table, Column, Integer, String, \
     DateTime, Text, ForeignKey
-from sqlalchemy.orm import sessionmaker, relationship, synonym, backref
+from sqlalchemy.orm import sessionmaker, scoped_session, relationship, synonym, backref
+from sqlalchemy.orm import eagerload, eagerload_all, joinedload, class_mapper, subqueryload
+from sqlalchemy.orm.attributes import InstrumentedAttribute
 from sqlalchemy.ext.hybrid import hybrid_property
 
 import datetime
 
 Base = declarative_base()
 
-engine = create_engine('sqlite://')
+engine = create_engine('sqlite://', echo=False)
 
-Session = sessionmaker(bind=engine)
+Session = scoped_session(sessionmaker(bind=engine))
 
 def associate(table1, table2):
     return Table(
     def __unicode__(self):
         return self.name
 
+    def __repr__(self):
+        return '%s(%r)' % (self.__class__.__name__, self.name)
+
     # class Meta:
     #     verbose_name = 'tag'
     #     verbose_name_plural = 'tags'
     def __unicode__(self):
         return self.name
 
+    def __repr__(self):
+        return '%s(%r)' % (self.__class__.__name__, self.name)
+
     # class Meta:
     #     verbose_name = 'author'
     #     verbose_name_plural = 'authors'
     def __unicode__(self):
         return self.name
 
+    def __repr__(self):
+        return '%s(%r)' % (self.__class__.__name__, self.name)
+
     # class Meta:
     #     verbose_name = 'publisher'
     #     verbose_name_plural = 'publishers'
     def __unicode__(self):
         return self.url
 
+    def __repr__(self):
+        return '%s(%r)' % (self.__class__.__name__, self.url)
+
     # class Meta:
     #     verbose_name = 'publisher link'
     #     verbose_name_plural = 'publisher links'
     publisher = relationship(Publisher, backref='books')
 
     tags = relationship(Tag, secondary=associate('book', 'tag'),
-                        backref='books')
+                        backref='book_set')
     created = Column(DateTime, default=datetime.datetime.utcnow)
 
     def __unicode__(self):
         return 'Book %s' % (self.id,)
 
+    def __repr__(self):
+        return '%s(%r)' % (self.__class__.__name__, self.id)
+
     # class Meta:
     #     verbose_name = 'book'
     #     verbose_name_plural = 'books'
     def __unicode__(self):
         return 'Chapter %s' % (self.id,)
 
+    def __repr__(self):
+        return '%s(%r, %r)' % (self.__class__.__name__, self.id, self.title)
+
     # class Meta:
     #     verbose_name = 'chapter'
     #     verbose_name_plural = 'chapters'
     def __unicode__(self):
         return 'ExtraBookInfo %s' % (self.id,)
 
+    def __repr__(self):
+        return '%s(%r)' % (self.__class__.__name__, self.id)
+
     # class Meta:
     #     verbose_name = 'extra book info'
     #     verbose_name_plural = 'extra book info'
     # To create a reverse m2m relationship for testing
     name = Column(String(80))
     books = relationship(Book, secondary=associate('category', 'book'),
-                         backref='categories')
+                         backref='category_set')
 
     def __unicode__(self):
         return self.name
 
+    def __repr__(self):
+        return '%s(%r)' % (self.__class__.__name__, self.name)
+
     # class Meta:
     #     verbose_name = 'category'
     #     verbose_name_plural = 'categories'
 
+def create_test_tables():
+    Base.metadata.create_all(engine)
 
 def create_test_data():
     session = Session()
     session.add(cooking_for_geeks)
     session.add(mongodb)
     session.commit()
-    fresh_book = session.query(Book).filter_by(title=u"Cooking for Geeks").one()
-    return {'book': fresh_book}
+
+    opts = []
+    for obj in class_mapper(Book).relationships:
+        if obj.secondary is None:
+            attr = getattr(Book, obj.key)
+            opts.append(subqueryload(attr))
+
+    # attrs = [obj for name, obj in inspect.getmembers(Book) if isinstance(obj, InstrumentedAttribute)]
+    # rels = [obj.key for obj in class_mapper(Book).relationships if obj.secondary is None]
+
+    # print list(class_mapper(Book).relationships)
+
+    # opts = eagerload_all(*rels)
+    # in order for this to work, fresh_book needs to either be bound to a session
+    # or needs to recursively eagerload all relationships
+    fresh_book = session.query(Book).options(*opts).filter_by(title=u"Cooking for Geeks").one()
+    # return {'book': fresh_book}
+
+    return {'book': cooking_for_geeks}
 
 
 ################################################################

denormalize/tests/test_sqlalchemy/test_backends.py

+from ..backend import LocMemBackendTestMixin
+from .common import ModelTestCase
+from . import models
+
+session = models.Session()
+
+class LocMemBackendTest(LocMemBackendTestMixin, ModelTestCase):
+    models = models
+
+    def update_chapter(self):
+        chapter = session.query(models.Chapter).filter_by(id=1).one()
+        chapter.title += u'!!!'
+        session.commit()
+        return chapter
+
+    def update_author(self):
+        author = session.query(models.Author).filter_by(id=1).one()
+        author.name = 'Another Name'
+        author.email = 'foo@example.com'
+        session.commit()
+        return author
+
+    def update_tag(self):
+        tag = session.query(models.Tag).filter_by(name="technology").one()
+        tag.name = 'tech'
+        session.commit()
+        return tag
+
+    def insert_chapter(self):
+        book1 = session.query(models.Book).filter_by(id=1).one()
+        chapter = models.Chapter(book=book1, title="Conclusion")
+        session.commit()
+        return book1, chapter
+
+    def move_chapter(self, chapter):
+        book2 = session.query(models.Book).filter_by(id=2).one()
+        chapter.book = book2
+        session.commit()
+        return book2
+
+    def delete_chapter(self, chapter):
+        # Delete the chapter
+        session.delete(chapter)
+        session.commit()
+
+    def add_tag(self, book2):
+        tag = models.Tag(name="foo")
+        book2.tags = []
+        book2.tags.append(tag)
+        session.commit()
+        return tag
+
+    def remove_tag_from_reverse(self, tag, book2):
+        tag = session.query(models.Tag).filter_by(id=tag.id).one()
+        tag.book_set.remove(book2)
+        tag.book_set.append(book2)
+        # 'clear' action
+        tag.book_set = []
+        book2.publisher.tags = [tag]
+        session.commit()
+
+    def inject_dirty(self, backend, book2):
+        newbook = self.models.Book(
+            title="Some title", publisher=book2.publisher)
+        session.commit()
+        self.assertTrue(newbook.id in backend._dirty['books'])
+        backend._dirty['books'].add(1)

denormalize/tests/test_sqlalchemy/test_collections.py

             models.engine.execute(table.delete())
         super(SqlAlchemyCollectionTest, cls).tearDownClass()
 
+    def check_related_models(self, bookcol, deps):
+        for rel in bookcol.select_related + bookcol.prefetch_related:
+            # Difference between the filter name and the path name
+            self.assertTrue(rel in deps, "Relation {0} not found!".format(rel))
+        self.assertIs(deps['publisher']['model']._model, self.models.Publisher)
+        self.assertIs(deps['extra_info']['model']._model, self.models.ExtraBookInfo)
+        self.assertIs(deps['chapter_set']['model']._model, self.models.Chapter)
+        self.assertIs(deps['authors']['model']._model, self.models.Author)
+        self.assertIs(deps['tags']['model']._model, self.models.Tag)
+        self.assertIs(deps['publisher__links']['model']._model, self.models.PublisherLink)
+
 # class DjangoORMTest(ModelTestCase):
 #     def test_reverse_field_registration(self):
 #         self.assertTrue(hasattr(models.A, 'b_set'))