Mike Bayer avatar Mike Bayer committed f5458e5

- merged sync_simplify branch
- The methodology behind "primaryjoin"/"secondaryjoin" has
been refactored. Behavior should be slightly more
intelligent, primarily in terms of error messages which
have been pared down to be more readable. In a slight
number of scenarios it can better resolve the correct
foreign key than before.
- moved collections unit test from relationships.py to collection.py
- PropertyLoader now has "synchronize_pairs" and "equated_pairs"
collections which allow easy access to the source/destination
parent/child relation between columns (might change names)
- factored out ClauseSynchronizer (finally)
- added many more tests for priamryjoin/secondaryjoin
error checks

Comments (0)

Files changed (12)

     - Added a more aggressive check for "uncompiled mappers",
       helps particularly with declarative layer [ticket:995]
 
+    - The methodology behind "primaryjoin"/"secondaryjoin" has
+      been refactored.  Behavior should be slightly more
+      intelligent, primarily in terms of error messages which
+      have been pared down to be more readable.  In a slight
+      number of scenarios it can better resolve the correct 
+      foreign key than before.
+
     - Added comparable_property(), adds query Comparator behavior
       to regular, unmanaged Python properties
 

lib/sqlalchemy/orm/dependency.py

 """
 
 from sqlalchemy.orm import sync
-from sqlalchemy.orm.sync import ONETOMANY,MANYTOONE,MANYTOMANY
 from sqlalchemy import sql, util, exceptions
+from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE, MANYTOMANY
 
 
 def create_dependency_processor(prop):
         self.passive_updates = prop.passive_updates
         self.enable_typechecks = prop.enable_typechecks
         self.key = prop.key
-
-        self._compile_synchronizers()
+        if not self.prop.synchronize_pairs:
+            raise exceptions.ArgumentError("Can't build a DependencyProcessor for relation %s.  No target attributes to populate between parent and child are present" % self.prop)
 
     def _get_instrumented_attribute(self):
         """Return the ``InstrumentedAttribute`` handled by this
 
         raise NotImplementedError()
 
-    def _compile_synchronizers(self):
-        """Assemble a list of *synchronization rules*.
-
-        These are fired to populate attributes from one side
-        of a relation to another.
-        """
-
-        self.syncrules = sync.ClauseSynchronizer(self.parent, self.mapper, self.direction)
-        if self.direction == sync.MANYTOMANY:
-            self.syncrules.compile(self.prop.primaryjoin, issecondary=False, foreign_keys=self.foreign_keys)
-            self.syncrules.compile(self.prop.secondaryjoin, issecondary=True, foreign_keys=self.foreign_keys)
-        else:
-            self.syncrules.compile(self.prop.primaryjoin, foreign_keys=self.foreign_keys)
-
 
     def _conditional_post_update(self, state, uowcommit, related):
         """Execute a post_update call.
         if state is not None and self.post_update:
             for x in related:
                 if x is not None:
-                    uowcommit.register_object(state, postupdate=True, post_update_cols=self.syncrules.dest_columns())
+                    uowcommit.register_object(state, postupdate=True, post_update_cols=[r for l, r in self.prop.synchronize_pairs])
                     break
 
     def _pks_changed(self, uowcommit, state):
-        return self.syncrules.source_changes(uowcommit, state)
+        raise NotImplementedError()
 
     def __str__(self):
         return "%s(%s)" % (self.__class__.__name__, str(self.prop))
         if dest is None or (not self.post_update and uowcommit.is_deleted(dest)):
             return
         self._verify_canload(child)
-        self.syncrules.execute(source, dest, source, child, clearkeys)
+        if clearkeys:
+            sync.clear(dest, self.mapper, self.prop.synchronize_pairs)
+        else:
+            sync.populate(source, self.parent, dest, self.mapper, self.prop.synchronize_pairs)
+
+    def _pks_changed(self, uowcommit, state):
+        return sync.source_changes(uowcommit, state, self.parent, self.prop.synchronize_pairs)
 
 class DetectKeySwitch(DependencyProcessor):
     """a special DP that works for many-to-one relations, fires off for
                     elem.dict[self.key]._state in switchers
                 ]:
                 uowcommit.register_object(s, listonly=self.passive_updates)
-                self.syncrules.execute(s.dict[self.key]._state, s, None, None, False)
+                sync.populate(s.dict[self.key]._state, self.mapper, s, self.parent, self.prop.synchronize_pairs)
+                #self.syncrules.execute(s.dict[self.key]._state, s, None, None, False)
+
+    def _pks_changed(self, uowcommit, state):
+        return sync.source_changes(uowcommit, state, self.mapper, self.prop.synchronize_pairs)
 
 class ManyToOneDP(DependencyProcessor):
     def __init__(self, prop):
 
 
     def _synchronize(self, state, child, associationrow, clearkeys, uowcommit):
-        source = child
-        dest = state
-        if dest is None or (not self.post_update and uowcommit.is_deleted(dest)):
+        if state is None or (not self.post_update and uowcommit.is_deleted(state)):
             return
-        self._verify_canload(child)
-        self.syncrules.execute(source, dest, dest, child, clearkeys)
+
+        if clearkeys or child is None:
+            sync.clear(state, self.parent, self.prop.synchronize_pairs)
+        else:
+            self._verify_canload(child)
+            sync.populate(child, self.mapper, state, self.parent, self.prop.synchronize_pairs)
 
 class ManyToManyDP(DependencyProcessor):
     def register_dependencies(self, uowcommit):
                 if not self.passive_updates and unchanged and self._pks_changed(uowcommit, state):
                     for child in unchanged:
                         associationrow = {}
-                        self.syncrules.update(associationrow, state, child, "old_")
+                        sync.update(state, self.parent, associationrow, "old_", self.prop.synchronize_pairs)
+                        sync.update(child, self.mapper, associationrow, "old_", self.prop.secondary_synchronize_pairs)
+
+                        #self.syncrules.update(associationrow, state, child, "old_")
                         secondary_update.append(associationrow)
 
         if secondary_delete:
         if associationrow is None:
             return
         self._verify_canload(child)
-        self.syncrules.execute(None, associationrow, state, child, clearkeys)
+        
+        sync.populate_dict(state, self.parent, associationrow, self.prop.synchronize_pairs)
+        sync.populate_dict(child, self.mapper, associationrow, self.prop.secondary_synchronize_pairs)
+
+    def _pks_changed(self, uowcommit, state):
+        return sync.source_changes(uowcommit, state, self.parent, self.prop.synchronize_pairs)
 
 class AssociationDP(OneToManyDP):
     def __init__(self, *args, **kwargs):

lib/sqlalchemy/orm/interfaces.py

 EXT_CONTINUE = EXT_PASS = util.symbol('EXT_CONTINUE')
 EXT_STOP = util.symbol('EXT_STOP')
 
+ONETOMANY = util.symbol('ONETOMANY')
+MANYTOONE = util.symbol('MANYTOONE')
+MANYTOMANY = util.symbol('MANYTOMANY')
+
 class MapperExtension(object):
     """Base implementation for customizing Mapper behavior.
 

lib/sqlalchemy/orm/mapper.py

         self._dependency_processors = []
         self._clause_adapter = None
         self._requires_row_aliasing = False
-
+        self.__inherits_equated_pairs = None
+        
         if not issubclass(class_, object):
             raise exceptions.ArgumentError("Class '%s' is not a new-style class" % class_.__name__)
 
         self.__should_log_info = logging.is_info_enabled(self.logger)
         self.__should_log_debug = logging.is_debug_enabled(self.logger)
 
-        self._compile_class()
-        self._compile_inheritance()
-        self._compile_extensions()
-        self._compile_properties()
-        self._compile_pks()
+        self.__compile_class()
+        self.__compile_inheritance()
+        self.__compile_extensions()
+        self.__compile_properties()
+        self.__compile_pks()
         global __new_mappers
         __new_mappers = True
         self.__log("constructed")
         to execute once all mappers have been constructed.
         """
 
-        self.__log("_initialize_properties() started")
+        self.__log("__initialize_properties() started")
         l = [(key, prop) for key, prop in self.__props.iteritems()]
         for key, prop in l:
             self.__log("initialize prop " + key)
             if getattr(prop, 'key', None) is None:
                 prop.init(key, self)
-        self.__log("_initialize_properties() complete")
+        self.__log("__initialize_properties() complete")
         self.__props_init = True
 
 
-    def _compile_extensions(self):
+    def __compile_extensions(self):
         """Go through the global_extensions list as well as the list
         of ``MapperExtensions`` specified for this ``Mapper`` and
         creates a linked list of those extensions.
         for ext in extlist:
             self.extension.append(ext)
 
-    def _compile_inheritance(self):
+    def __compile_inheritance(self):
         """Configure settings related to inherting and/or inherited mappers being present."""
 
         if self.inherits:
                 self.single = True
             if not self.local_table is self.inherits.local_table:
                 if self.concrete:
-                    self._synchronizer = None
                     self.mapped_table = self.local_table
                     for mapper in self.iterate_to_root():
                         if mapper.polymorphic_on:
                         # stuff we dont want (allows test/inheritance.InheritTest4 to pass)
                         self.inherit_condition = sql.join(self.inherits.local_table, self.local_table).onclause
                     self.mapped_table = sql.join(self.inherits.mapped_table, self.local_table, self.inherit_condition)
-                    # generate sync rules.  similarly to creating the on clause, specify a
-                    # stricter set of tables to create "sync rules" by,based on the immediate
-                    # inherited table, rather than all inherited tables
-                    self._synchronizer = sync.ClauseSynchronizer(self, self, sync.ONETOMANY)
-                    if self.inherit_foreign_keys:
-                        fks = util.Set(self.inherit_foreign_keys)
-                    else:
-                        fks = None
-                    self._synchronizer.compile(self.mapped_table.onclause, foreign_keys=fks)
+                    
+                    fks = util.to_set(self.inherit_foreign_keys)
+                    self.__inherits_equated_pairs = sqlutil.criterion_as_pairs(self.mapped_table.onclause, consider_as_foreign_keys=fks)
             else:
-                self._synchronizer = None
                 self.mapped_table = self.local_table
             if self.polymorphic_identity is not None:
                 self.inherits.polymorphic_map[self.polymorphic_identity] = self
         else:
             self._all_tables = util.Set()
             self.base_mapper = self
-            self._synchronizer = None
             self.mapped_table = self.local_table
             if self.polymorphic_identity:
                 if self.polymorphic_on is None:
         if self.mapped_table is None:
             raise exceptions.ArgumentError("Mapper '%s' does not have a mapped_table specified.  (Are you using the return value of table.create()?  It no longer has a return value.)" % str(self))
 
-    def _compile_pks(self):
+    def __compile_pks(self):
 
         self.tables = sqlutil.find_tables(self.mapped_table)
 
 
             return getattr(getattr(cls, clskey), key)
 
-    def _compile_properties(self):
+    def __compile_properties(self):
 
         # object attribute names mapped to MapperProperty objects
         self.__props = util.OrderedDict()
         for mapper in self._inheriting_mappers:
             mapper._adapt_inherited_property(key, prop)
 
-    def _compile_class(self):
+    def __compile_class(self):
         """If this mapper is to be a primary mapper (i.e. the
         non_primary flag is not set), associate this Mapper with the
         given class_ and entity name.
                     # TODO: this fires off more than needed, try to organize syncrules
                     # per table
                     for m in util.reversed(list(mapper.iterate_to_root())):
-                        if m._synchronizer:
-                            m._synchronizer.execute(state, state)
+                        if m.__inherits_equated_pairs:
+                            m._synchronize_inherited(state)
 
                     # testlib.pragma exempt:__hash__
                     inserted_objects.add((state, connection))
                     if 'after_update' in mapper.extension.methods:
                         mapper.extension.after_update(mapper, connection, state.obj())
 
+    def _synchronize_inherited(self, state):
+        sync.populate(state, self, state, self, self.__inherits_equated_pairs)
+
     def _postfetch(self, uowtransaction, connection, table, state, resultproxy, params, value_params):
         """After an ``INSERT`` or ``UPDATE``, assemble newly generated
         values on an instance.  For columns which are marked as being generated

lib/sqlalchemy/orm/properties.py

 """
 
 from sqlalchemy import sql, schema, util, exceptions, logging
-from sqlalchemy.sql.util import ClauseAdapter
+from sqlalchemy.sql.util import ClauseAdapter, criterion_as_pairs, find_columns
 from sqlalchemy.sql import visitors, operators, ColumnElement
 from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency, object_mapper
 from sqlalchemy.orm import session as sessionlib
 from sqlalchemy.orm.mapper import _class_to_mapper
 from sqlalchemy.orm.util import CascadeOptions, PropertyAliasedClauses
-from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator, MapperProperty
+from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator, MapperProperty, ONETOMANY, MANYTOONE, MANYTOMANY
 from sqlalchemy.exceptions import ArgumentError
 
-
 __all__ = ('ColumnProperty', 'CompositeProperty', 'SynonymProperty',
            'ComparableProperty', 'PropertyLoader', 'BackRef')
 
             
         def __eq__(self, other):
             if other is None:
-                if self.prop.direction == sync.ONETOMANY:
+                if self.prop.direction == ONETOMANY:
                     return ~sql.exists([1], self.prop.primaryjoin)
                 else:
                     return self.prop._optimized_compare(None)
             
         def __ne__(self, other):
             if other is None:
-                if self.prop.direction == sync.MANYTOONE:
+                if self.prop.direction == MANYTOONE:
                     return sql.or_(*[x!=None for x in self.prop.foreign_keys])
                 elif self.prop.uselist:
                     return self.any()
             return self.argument.class_
 
     def do_init(self):
-        self._determine_targets()
-        self._determine_joins()
-        self._determine_fks()
-        self._determine_direction()
-        self._determine_remote_side()
+        self.__determine_targets()
+        self.__determine_joins()
+        self.__determine_fks()
+        self.__determine_direction()
+        self.__determine_remote_side()
         self._post_init()
 
-    def _determine_targets(self):
+    def __determine_targets(self):
         if isinstance(self.argument, type):
             self.mapper = mapper.class_mapper(self.argument, entity_name=self.entity_name, compile=False)
         elif isinstance(self.argument, mapper.Mapper):
 
         if self.cascade.delete_orphan:
             if self.parent.class_ is self.mapper.class_:
-                raise exceptions.ArgumentError("In relationship '%s', can't establish 'delete-orphan' cascade rule on a self-referential relationship.  You probably want cascade='all', which includes delete cascading but not orphan detection." %(str(self)))
+                raise exceptions.ArgumentError("In relationship '%s', can't establish 'delete-orphan' cascade "
+                            "rule on a self-referential relationship.  "
+                            "You probably want cascade='all', which includes delete cascading but not orphan detection." %(str(self)))
             self.mapper.primary_mapper().delete_orphans.append((self.key, self.parent.class_))
 
-    def _determine_joins(self):
+    def __determine_joins(self):
         if self.secondaryjoin is not None and self.secondary is None:
             raise exceptions.ArgumentError("Property '" + self.key + "' specified with secondary join condition but no secondary argument")
         # if join conditions were not specified, figure them out based on foreign keys
                 if self.primaryjoin is None:
                     self.primaryjoin = _search_for_join(self.parent, self.target).onclause
         except exceptions.ArgumentError, e:
-            raise exceptions.ArgumentError("""Error determining primary and/or secondary join for relationship '%s'. If the underlying error cannot be corrected, you should specify the 'primaryjoin' (and 'secondaryjoin', if there is an association table present) keyword arguments to the relation() function (or for backrefs, by specifying the backref using the backref() function with keyword arguments) to explicitly specify the join conditions. Nested error is \"%s\"""" % (str(self), str(e)))
+            raise exceptions.ArgumentError("Could not determine join condition between parent/child tables on relation %s.  "
+                        "Specify a 'primaryjoin' expression.  If this is a many-to-many relation, 'secondaryjoin' is needed as well." % (self))
 
 
-    def _col_is_part_of_mappings(self, column):
+    def __col_is_part_of_mappings(self, column):
         if self.secondary is None:
             return self.parent.mapped_table.c.contains_column(column) or \
                 self.target.c.contains_column(column)
                 self.target.c.contains_column(column) or \
                 self.secondary.c.contains_column(column) is not None
         
-    def _determine_fks(self):
+    def __determine_fks(self):
         if self._legacy_foreignkey and not self._refers_to_parent_table():
             self.foreign_keys = self._legacy_foreignkey
 
-        self._opposite_side = util.Set()
+        arg_foreign_keys = self.foreign_keys
+        
+        eq_pairs = criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=arg_foreign_keys, any_operator=self.viewonly)
+        eq_pairs = [(l, r) for l, r in eq_pairs if self.__col_is_part_of_mappings(l) and self.__col_is_part_of_mappings(r)]
 
-        if self.foreign_keys:
-            def visit_binary(binary):
-                if binary.operator != operators.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
-                    return
-                if binary.left in self.foreign_keys:
-                    self._opposite_side.add(binary.right)
-                if binary.right in self.foreign_keys:
-                    self._opposite_side.add(binary.left)
+        if not eq_pairs:
+            if not self.viewonly and criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=arg_foreign_keys, any_operator=True):
+                raise exceptions.ArgumentError("Could not locate any equated column pairs for primaryjoin condition '%s' on relation %s. "
+                    "If no equated pairs exist, the relation must be marked as viewonly=True." % (self.primaryjoin, self)
+                )
+            else:
+                raise exceptions.ArgumentError("Could not determine relation direction for primaryjoin condition '%s', on relation %s. "
+                "Specify the foreign_keys argument to indicate which columns on the relation are foreign." % (self.primaryjoin, self))
+        
+        self.foreign_keys = util.OrderedSet([r for l, r in eq_pairs])
+        self._opposite_side = util.OrderedSet([l for l, r in eq_pairs])
+        self.synchronize_pairs = eq_pairs
+        
+        if self.secondaryjoin:
+            sq_pairs = criterion_as_pairs(self.secondaryjoin, consider_as_foreign_keys=arg_foreign_keys)
+            sq_pairs = [(l, r) for l, r in sq_pairs if self.__col_is_part_of_mappings(l) and self.__col_is_part_of_mappings(r)]
+            
+            if not sq_pairs:
+                if not self.viewonly and criterion_as_pairs(self.secondaryjoin, consider_as_foreign_keys=arg_foreign_keys, any_operator=True):
+                    raise exceptions.ArgumentError("Could not locate any equated column pairs for secondaryjoin condition '%s' on relation %s. "
+                        "If no equated pairs exist, the relation must be marked as viewonly=True." % (self.secondaryjoin, self)
+                    )
+                else:
+                    raise exceptions.ArgumentError("Could not determine relation direction for secondaryjoin condition '%s', on relation %s. "
+                    "Specify the foreign_keys argument to indicate which columns on the relation are foreign." % (self.secondaryjoin, self))
+
+            self.foreign_keys.update([r for l, r in sq_pairs])
+            self._opposite_side.update([l for l, r in sq_pairs])
+            self.secondary_synchronize_pairs = sq_pairs
         else:
-            self.foreign_keys = util.Set()
-            def visit_binary(binary):
-                if binary.operator != operators.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
-                    return
+            self.secondary_synchronize_pairs = None
+    
+    def equated_pairs(self):
+        return zip(self.local_side, self.remote_side)
+    equated_pairs = property(equated_pairs)
+    
+    def __determine_remote_side(self):
+        if self.remote_side:
+            if self.direction is MANYTOONE:
+                eq_pairs = criterion_as_pairs(self.primaryjoin, consider_as_referenced_keys=self.remote_side, any_operator=True)
+            else:
+                eq_pairs = criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=self.remote_side, any_operator=True)
 
-                # this check is for when the user put the "view_only" flag on and has tables that have nothing
-                # to do with the relationship's parent/child mappings in the join conditions.  we dont want cols
-                # or clauses related to those external tables dealt with.  see orm.relationships.ViewOnlyTest
-                if not self._col_is_part_of_mappings(binary.left) or not self._col_is_part_of_mappings(binary.right):
-                    return
+            if self.secondaryjoin:
+                sq_pairs = criterion_as_pairs(self.secondaryjoin, consider_as_foreign_keys=self.foreign_keys, any_operator=True)
+                sq_pairs = [(l, r) for l, r in sq_pairs if self.__col_is_part_of_mappings(l) and self.__col_is_part_of_mappings(r)]
+                eq_pairs += sq_pairs
+        else:
+            eq_pairs = zip(self._opposite_side, self.foreign_keys)
 
-                for f in binary.left.foreign_keys:
-                    if f.references(binary.right.table):
-                        self.foreign_keys.add(binary.left)
-                        self._opposite_side.add(binary.right)
-                for f in binary.right.foreign_keys:
-                    if f.references(binary.left.table):
-                        self.foreign_keys.add(binary.right)
-                        self._opposite_side.add(binary.left)
-
-        visitors.traverse(self.primaryjoin, visit_binary=visit_binary)
-
-        if not self.foreign_keys:
-            raise exceptions.ArgumentError(
-                "Can't locate any foreign key columns in primary join "
-                "condition '%s' for relationship '%s'.  Specify "
-                "'foreign_keys' argument to indicate which columns in "
-                "the join condition are foreign." %(str(self.primaryjoin), str(self)))
-
-        if self.secondaryjoin is not None:
-            visitors.traverse(self.secondaryjoin, visit_binary=visit_binary)
-
-
-    def _determine_direction(self):
+        if self.direction is MANYTOONE:
+            self.remote_side, self.local_side = [util.OrderedSet(s) for s in zip(*eq_pairs)]
+        else:
+            self.local_side, self.remote_side = [util.OrderedSet(s) for s in zip(*eq_pairs)]
+            
+    def __determine_direction(self):
         """Determine our *direction*, i.e. do we represent one to
         many, many to many, etc.
         """
 
         if self.secondaryjoin is not None:
-            self.direction = sync.MANYTOMANY
+            self.direction = MANYTOMANY
         elif self._refers_to_parent_table():
             # for a self referential mapper, if the "foreignkey" is a single or composite primary key,
             # then we are "many to one", since the remote site of the relationship identifies a singular entity.
             if self._legacy_foreignkey:
                 for f in self._legacy_foreignkey:
                     if not f.primary_key:
-                        self.direction = sync.ONETOMANY
+                        self.direction = ONETOMANY
                     else:
-                        self.direction = sync.MANYTOONE
+                        self.direction = MANYTOONE
 
             elif self.remote_side:
                 for f in self.foreign_keys:
                     if f in self.remote_side:
-                        self.direction = sync.ONETOMANY
+                        self.direction = ONETOMANY
                         return
                 else:
-                    self.direction = sync.MANYTOONE
+                    self.direction = MANYTOONE
             else:
-                self.direction = sync.ONETOMANY
+                self.direction = ONETOMANY
         else:
             for mappedtable, parenttable in [(self.mapper.mapped_table, self.parent.mapped_table), (self.mapper.local_table, self.parent.local_table)]:
                 onetomany = [c for c in self.foreign_keys if mappedtable.c.contains_column(c)]
                 elif onetomany and manytoone:
                     continue
                 elif onetomany:
-                    self.direction = sync.ONETOMANY
+                    self.direction = ONETOMANY
                     break
                 elif manytoone:
-                    self.direction = sync.MANYTOONE
+                    self.direction = MANYTOONE
                     break
             else:
                 raise exceptions.ArgumentError(
                     "the child's mapped tables.  Specify 'foreign_keys' "
                     "argument." % (str(self)))
 
-    def _determine_remote_side(self):
-        if not self.remote_side:
-            if self.direction is sync.MANYTOONE:
-                self.remote_side = util.Set(self._opposite_side)
-            elif self.direction is sync.ONETOMANY or self.direction is sync.MANYTOMANY:
-                self.remote_side = util.Set(self.foreign_keys)
-
-        self.local_side = util.Set(self._opposite_side).union(util.Set(self.foreign_keys)).difference(self.remote_side)
-
     def _post_init(self):
         if logging.is_info_enabled(self.logger):
             self.logger.info(str(self) + " setup primary join " + str(self.primaryjoin))
             self.logger.info(str(self) + " setup secondary join " + str(self.secondaryjoin))
-            self.logger.info(str(self) + " foreign keys " + str([str(c) for c in self.foreign_keys]))
-            self.logger.info(str(self) + " remote columns " + str([str(c) for c in self.remote_side]))
-            self.logger.info(str(self) + " relation direction " + (self.direction is sync.ONETOMANY and "one-to-many" or (self.direction is sync.MANYTOONE and "many-to-one" or "many-to-many")))
+            self.logger.info(str(self) + " synchronize pairs " + ",".join(["(%s => %s)" % (l, r) for l, r in self.synchronize_pairs]))
+            self.logger.info(str(self) + " equated pairs " + ",".join(["(%s == %s)" % (l, r) for l, r in self.equated_pairs]))
+            self.logger.info(str(self) + " relation direction " + (self.direction is ONETOMANY and "one-to-many" or (self.direction is MANYTOONE and "many-to-one" or "many-to-many")))
 
-        if self.uselist is None and self.direction is sync.MANYTOONE:
+        if self.uselist is None and self.direction is MANYTOONE:
             self.uselist = False
 
         if self.uselist is None:
             primaryjoin = self.primaryjoin
             
             if fromselectable is not frommapper.local_table:
-                if self.direction is sync.ONETOMANY:
+                if self.direction is ONETOMANY:
                     primaryjoin = ClauseAdapter(fromselectable, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin)
-                elif self.direction is sync.MANYTOONE:
+                elif self.direction is MANYTOONE:
                     primaryjoin = ClauseAdapter(fromselectable, include=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin)
                 elif self.secondaryjoin:
                     primaryjoin = ClauseAdapter(fromselectable, exclude=self.foreign_keys, equivalents=parent_equivalents).traverse(primaryjoin)

lib/sqlalchemy/orm/strategies.py

 class LazyLoader(AbstractRelationLoader):
     def init(self):
         super(LazyLoader, self).init()
-        (self.lazywhere, self.lazybinds, self.equated_columns) = self._create_lazy_clause(self.parent_property)
+        (self.__lazywhere, self.__bind_to_col, self._equated_columns) = self.__create_lazy_clause(self.parent_property)
         
-        self.logger.info(str(self.parent_property) + " lazy loading clause " + str(self.lazywhere))
+        self.logger.info(str(self.parent_property) + " lazy loading clause " + str(self.__lazywhere))
 
         # determine if our "lazywhere" clause is the same as the mapper's
         # get() clause.  then we can just use mapper.get()
         #from sqlalchemy.orm import query
-        self.use_get = not self.uselist and self.mapper._get_clause[0].compare(self.lazywhere)
+        self.use_get = not self.uselist and self.mapper._get_clause[0].compare(self.__lazywhere)
         if self.use_get:
             self.logger.info(str(self.parent_property) + " will use query.get() to optimize instance loads")
 
             return self._lazy_none_clause(reverse_direction)
             
         if not reverse_direction:
-            (criterion, lazybinds, rev) = (self.lazywhere, self.lazybinds, self.equated_columns)
+            (criterion, bind_to_col, rev) = (self.__lazywhere, self.__bind_to_col, self._equated_columns)
         else:
-            (criterion, lazybinds, rev) = LazyLoader._create_lazy_clause(self.parent_property, reverse_direction=reverse_direction)
-        bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds])
+            (criterion, bind_to_col, rev) = LazyLoader.__create_lazy_clause(self.parent_property, reverse_direction=reverse_direction)
 
         def visit_bindparam(bindparam):
             mapper = reverse_direction and self.parent_property.mapper or self.parent_property.parent
     
     def _lazy_none_clause(self, reverse_direction=False):
         if not reverse_direction:
-            (criterion, lazybinds, rev) = (self.lazywhere, self.lazybinds, self.equated_columns)
+            (criterion, bind_to_col, rev) = (self.__lazywhere, self.__bind_to_col, self._equated_columns)
         else:
-            (criterion, lazybinds, rev) = LazyLoader._create_lazy_clause(self.parent_property, reverse_direction=reverse_direction)
-        bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds])
+            (criterion, bind_to_col, rev) = LazyLoader.__create_lazy_clause(self.parent_property, reverse_direction=reverse_direction)
 
         def visit_binary(binary):
             mapper = reverse_direction and self.parent_property.mapper or self.parent_property.parent
                     instance._state.reset(self.key)
             return (new_execute, None, None)
 
-    def _create_lazy_clause(cls, prop, reverse_direction=False):
-        (primaryjoin, secondaryjoin, remote_side) = (prop.primaryjoin, prop.secondaryjoin, prop.remote_side)
-        
+    def __create_lazy_clause(cls, prop, reverse_direction=False):
         binds = {}
         equated_columns = {}
 
+        secondaryjoin = prop.secondaryjoin
+        equated = dict(prop.equated_pairs)
+        
         def should_bind(targetcol, othercol):
-            if not prop._col_is_part_of_mappings(targetcol):
-                return False
-                
             if reverse_direction and not secondaryjoin:
-                return targetcol in remote_side
+                return othercol in equated
             else:
-                return othercol in remote_side
+                return targetcol in equated
 
         def visit_binary(binary):
-            if not isinstance(binary.left, sql.ColumnElement) or not isinstance(binary.right, sql.ColumnElement):
-                return
             leftcol = binary.left
             rightcol = binary.right
 
             equated_columns[leftcol] = rightcol
 
             if should_bind(leftcol, rightcol):
-                if leftcol in binds:
-                    binary.left = binds[leftcol]
-                else:
-                    binary.left = binds[leftcol] = sql.bindparam(None, None, type_=binary.right.type)
+                if leftcol not in binds:
+                    binds[leftcol] = sql.bindparam(None, None, type_=binary.right.type)
+                binary.left = binds[leftcol]
+            elif should_bind(rightcol, leftcol):
+                if rightcol not in binds:
+                    binds[rightcol] = sql.bindparam(None, None, type_=binary.left.type)
+                binary.right = binds[rightcol]
 
-            # the "left is not right" compare is to handle part of a join clause that is "table.c.col1==table.c.col1",
-            # which can happen in rare cases (test/orm/relationships.py RelationTest2)
-            if leftcol is not rightcol and should_bind(rightcol, leftcol):
-                if rightcol in binds:
-                    binary.right = binds[rightcol]
-                else:
-                    binary.right = binds[rightcol] = sql.bindparam(None, None, type_=binary.left.type)
-
-                
-        lazywhere = primaryjoin
+        lazywhere = prop.primaryjoin
         
-        if not secondaryjoin or not reverse_direction:
+        if not prop.secondaryjoin or not reverse_direction:
             lazywhere = visitors.traverse(lazywhere, clone=True, visit_binary=visit_binary)
         
-        if secondaryjoin is not None:
+        if prop.secondaryjoin is not None:
             if reverse_direction:
                 secondaryjoin = visitors.traverse(secondaryjoin, clone=True, visit_binary=visit_binary)
             lazywhere = sql.and_(lazywhere, secondaryjoin)
-        return (lazywhere, binds, equated_columns)
-    _create_lazy_clause = classmethod(_create_lazy_clause)
+    
+        bind_to_col = dict([(binds[col].key, col) for col in binds])
+        
+        return (lazywhere, bind_to_col, equated_columns)
+    __create_lazy_clause = classmethod(__create_lazy_clause)
     
 LazyLoader.logger = logging.class_logger(LazyLoader)
 
             ident = []
             allnulls = True
             for primary_key in prop.mapper.primary_key: 
-                val = instance_mapper._get_committed_attr_by_column(instance, strategy.equated_columns[primary_key])
+                val = instance_mapper._get_committed_attr_by_column(instance, strategy._equated_columns[primary_key])
                 allnulls = allnulls and val is None
                 ident.append(val)
             if allnulls:

lib/sqlalchemy/orm/sync.py

 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
-"""Contains the ClauseSynchronizer class, which is used to map
-attributes between two objects in a manner corresponding to a SQL
-clause that compares column values.
+"""private module containing functions used for copying data between instances
+based on join conditions.
 """
 
 from sqlalchemy import schema, exceptions, util
-from sqlalchemy.sql import visitors, operators
+from sqlalchemy.sql import visitors, operators, util as sqlutil
 from sqlalchemy import logging
 from sqlalchemy.orm import util as mapperutil
+from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE, MANYTOMANY  # legacy
 
-ONETOMANY = 0
-MANYTOONE = 1
-MANYTOMANY = 2
+def populate(source, source_mapper, dest, dest_mapper, synchronize_pairs):
+    for l, r in synchronize_pairs:
+        try:
+            value = source_mapper._get_state_attr_by_column(source, l)
+        except exceptions.UnmappedColumnError:
+            _raise_col_to_prop(False, source_mapper, l, dest_mapper, r)
 
-class ClauseSynchronizer(object):
-    """Given a SQL clause, usually a series of one or more binary
-    expressions between columns, and a set of 'source' and
-    'destination' mappers, compiles a set of SyncRules corresponding
-    to that information.
+        try:
+            dest_mapper._set_state_attr_by_column(dest, r, value)
+        except exceptions.UnmappedColumnError:
+            self._raise_col_to_prop(True, source_mapper, l, dest_mapper, r)
 
-    The ClauseSynchronizer can then be executed given a set of
-    parent/child objects or destination dictionary, which will iterate
-    through each of its SyncRules and execute them.  Each SyncRule
-    will copy the value of a single attribute from the parent to the
-    child, corresponding to the pair of columns in a particular binary
-    expression, using the source and destination mappers to map those
-    two columns to object attributes within parent and child.
-    """
+def clear(dest, dest_mapper, synchronize_pairs):
+    for l, r in synchronize_pairs:
+        if r.primary_key:
+            raise exceptions.AssertionError("Dependency rule tried to blank-out primary key column '%s' on instance '%s'" % (r, mapperutil.state_str(dest)))
+        try:
+            dest_mapper._set_state_attr_by_column(dest, r, None)
+        except exceptions.UnmappedColumnError:
+            _raise_col_to_prop(True, None, l, dest_mapper, r)
 
-    def __init__(self, parent_mapper, child_mapper, direction):
-        self.parent_mapper = parent_mapper
-        self.child_mapper = child_mapper
-        self.direction = direction
-        self.syncrules = []
+def update(source, source_mapper, dest, old_prefix, synchronize_pairs):
+    for l, r in synchronize_pairs:
+        try:
+            oldvalue = source_mapper._get_committed_attr_by_column(source.obj(), l)
+            value = source_mapper._get_state_attr_by_column(source, l)
+        except exceptions.UnmappedColumnError:
+            self._raise_col_to_prop(False, source_mapper, l, None, r)
+        dest[r.key] = value
+        dest[old_prefix + r.key] = oldvalue
 
-    def compile(self, sqlclause, foreign_keys=None, issecondary=None):
-        def compile_binary(binary):
-            """Assemble a SyncRule given a single binary condition."""
+def populate_dict(source, source_mapper, dict_, synchronize_pairs):
+    for l, r in synchronize_pairs:
+        try:
+            value = source_mapper._get_state_attr_by_column(source, l)
+        except exceptions.UnmappedColumnError:
+            _raise_col_to_prop(False, source_mapper, l, None, r)
+            
+        dict_[r.key] = value
 
-            if binary.operator != operators.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
-                return
+def source_changes(uowcommit, source, source_mapper, synchronize_pairs):
+    for l, r in synchronize_pairs:
+        try:
+            prop = source_mapper._get_col_to_prop(l)
+        except exceptions.UnmappedColumnError:
+            _raise_col_to_prop(False, source_mapper, l, None, r)
+        (added, unchanged, deleted) = uowcommit.get_attribute_history(source, prop.key, passive=True)
+        if added and deleted:
+            return True
+    else:
+        return False
 
-            source_column = None
-            dest_column = None
+def dest_changes(uowcommit, dest, dest_mapper, synchronize_pairs):
+    for l, r in synchronize_pairs:
+        try:
+            prop = dest_mapper._get_col_to_prop(r)
+        except exceptions.UnmappedColumnError:
+            _raise_col_to_prop(True, None, l, dest_mapper, r)
+        (added, unchanged, deleted) = uowcommit.get_attribute_history(dest, prop.key, passive=True)
+        if added and deleted:
+            return True
+    else:
+        return False
 
-            if foreign_keys is None:
-                if binary.left.table == binary.right.table:
-                    raise exceptions.ArgumentError("need foreign_keys argument for self-referential sync")
-
-                if binary.left in util.Set([f.column for f in binary.right.foreign_keys]):
-                    dest_column = binary.right
-                    source_column = binary.left
-                elif binary.right in util.Set([f.column for f in binary.left.foreign_keys]):
-                    dest_column = binary.left
-                    source_column = binary.right
-            else:
-                if binary.left in foreign_keys:
-                    source_column = binary.right
-                    dest_column = binary.left
-                elif binary.right in foreign_keys:
-                    source_column = binary.left
-                    dest_column = binary.right
-
-            if source_column and dest_column:
-                if self.direction == ONETOMANY:
-                    self.syncrules.append(SyncRule(self.parent_mapper, source_column, dest_column, dest_mapper=self.child_mapper))
-                elif self.direction == MANYTOONE:
-                    self.syncrules.append(SyncRule(self.child_mapper, source_column, dest_column, dest_mapper=self.parent_mapper))
-                else:
-                    if not issecondary:
-                        self.syncrules.append(SyncRule(self.parent_mapper, source_column, dest_column, dest_mapper=self.child_mapper, issecondary=issecondary))
-                    else:
-                        self.syncrules.append(SyncRule(self.child_mapper, source_column, dest_column, dest_mapper=self.parent_mapper, issecondary=issecondary))
-
-        rules_added = len(self.syncrules)
-        visitors.traverse(sqlclause, visit_binary=compile_binary)
-        if len(self.syncrules) == rules_added:
-            raise exceptions.ArgumentError("No syncrules generated for join criterion " + str(sqlclause))
-
-    def dest_columns(self):
-        return [r.dest_column for r in self.syncrules if r.dest_column is not None]
-
-    def update(self, dest, parent, child, old_prefix):
-        for rule in self.syncrules:
-            rule.update(dest, parent, child, old_prefix)
+def _raise_col_to_prop(isdest, source_mapper, source_column, dest_mapper, dest_column):
+    if isdest:
+        raise exceptions.UnmappedColumnError("Can't execute sync rule for destination column '%s'; mapper '%s' does not map this column.  Try using an explicit `foreign_keys` collection which does not include this column (or use a viewonly=True relation)." % (dest_column, source_mapper))
+    else:
+        raise exceptions.UnmappedColumnError("Can't execute sync rule for source column '%s'; mapper '%s' does not map this column.  Try using an explicit `foreign_keys` collection which does not include destination column '%s' (or use a viewonly=True relation)." % (source_column, source_mapper, dest_column))
         
-    def execute(self, source, dest, obj=None, child=None, clearkeys=None):
-        for rule in self.syncrules:
-            rule.execute(source, dest, obj, child, clearkeys)
-    
-    def source_changes(self, uowcommit, source):
-        for rule in self.syncrules:
-            if rule.source_changes(uowcommit, source):
-                return True
-        else:
-            return False
-            
-class SyncRule(object):
-    """An instruction indicating how to populate the objects on each
-    side of a relationship.
-
-    E.g. if table1 column A is joined against table2 column
-    B, and we are a one-to-many from table1 to table2, a syncrule
-    would say *take the A attribute from object1 and assign it to the
-    B attribute on object2*.
-    """
-
-    def __init__(self, source_mapper, source_column, dest_column, dest_mapper=None, issecondary=None):
-        self.source_mapper = source_mapper
-        self.source_column = source_column
-        self.issecondary = issecondary
-        self.dest_mapper = dest_mapper
-        self.dest_column = dest_column
-        
-        #print "SyncRule", source_mapper, source_column, dest_column, dest_mapper
-
-    def dest_primary_key(self):
-        # late-evaluating boolean since some syncs are created
-        # before the mapper has assembled pks
-        try:
-            return self._dest_primary_key
-        except AttributeError:
-            self._dest_primary_key = self.dest_mapper is not None and self.dest_column in self.dest_mapper._pks_by_table[self.dest_column.table] and not self.dest_mapper.allow_null_pks
-            return self._dest_primary_key
-    
-    def _raise_col_to_prop(self, isdest):
-        if isdest:
-            raise exceptions.UnmappedColumnError("Can't execute sync rule for destination column '%s'; mapper '%s' does not map this column.  Try using an explicit `foreign_keys` collection which does not include this column (or use a viewonly=True relation)." % (self.dest_column, self.dest_mapper))
-        else:
-            raise exceptions.UnmappedColumnError("Can't execute sync rule for source column '%s'; mapper '%s' does not map this column.  Try using an explicit `foreign_keys` collection which does not include destination column '%s' (or use a viewonly=True relation)." % (self.source_column, self.source_mapper, self.dest_column))
-                
-    def source_changes(self, uowcommit, source):
-        try:
-            prop = self.source_mapper._get_col_to_prop(self.source_column)
-        except exceptions.UnmappedColumnError:
-            self._raise_col_to_prop(False)
-        (added, unchanged, deleted) = uowcommit.get_attribute_history(source, prop.key, passive=True)
-        return bool(added and deleted)
-    
-    def update(self, dest, parent, child, old_prefix):
-        if self.issecondary is False:
-            source = parent
-        elif self.issecondary is True:
-            source = child
-        try:
-            oldvalue = self.source_mapper._get_committed_attr_by_column(source.obj(), self.source_column)
-            value = self.source_mapper._get_state_attr_by_column(source, self.source_column)
-        except exceptions.UnmappedColumnError:
-            self._raise_col_to_prop(False)
-        dest[self.dest_column.key] = value
-        dest[old_prefix + self.dest_column.key] = oldvalue
-        
-    def execute(self, source, dest, parent, child, clearkeys):
-        # TODO: break the "dictionary" case into a separate method like 'update' above,
-        # reduce conditionals
-        if source is None:
-            if self.issecondary is False:
-                source = parent
-            elif self.issecondary is True:
-                source = child
-        if clearkeys or source is None:
-            value = None
-            clearkeys = True
-        else:
-            try:
-                value = self.source_mapper._get_state_attr_by_column(source, self.source_column)
-            except exceptions.UnmappedColumnError:
-                self._raise_col_to_prop(False)
-        if isinstance(dest, dict):
-            dest[self.dest_column.key] = value
-        else:
-            if clearkeys and self.dest_primary_key():
-                raise exceptions.AssertionError("Dependency rule tried to blank-out primary key column '%s' on instance '%s'" % (str(self.dest_column), mapperutil.state_str(dest)))
-
-            if logging.is_debug_enabled(self.logger):
-                self.logger.debug("execute() instances: %s(%s)->%s(%s) ('%s')" % (mapperutil.state_str(source), str(self.source_column), mapperutil.state_str(dest), str(self.dest_column), value))
-            try:
-                self.dest_mapper._set_state_attr_by_column(dest, self.dest_column, value)
-            except exceptions.UnmappedColumnError:
-                self._raise_col_to_prop(True)
-
-SyncRule.logger = logging.class_logger(SyncRule)
-

lib/sqlalchemy/schema.py

     def references(self, column):
         """Return True if this references the given column via a foreign key."""
         for fk in self.foreign_keys:
-            if fk.column is column:
+            if fk.references(column.table):
                 return True
         else:
             return False

lib/sqlalchemy/sql/util.py

-from sqlalchemy import exceptions, schema, topological, util
+from sqlalchemy import exceptions, schema, topological, util, sql
 from sqlalchemy.sql import expression, operators, visitors
 from itertools import chain
 
 """Utility functions that build upon SQL and Schema constructs."""
 
 def sort_tables(tables, reverse=False):
+    """sort a collection of Table objects in order of their foreign-key dependency."""
+    
     tuples = []
     class TVisitor(schema.SchemaVisitor):
         def visit_foreign_key(_self, fkey):
         return sequence
 
 def find_tables(clause, check_columns=False, include_aliases=False):
+    """locate Table objects within the given expression."""
+    
     tables = []
     kwargs = {}
     if include_aliases:
     return tables
 
 def find_columns(clause):
+    """locate Column objects within the given expression."""
+    
     cols = util.Set()
     def visit_column(col):
         cols.add(col)
 
     return expression.ColumnSet(columns.difference(omit))
 
+def criterion_as_pairs(expression, consider_as_foreign_keys=None, consider_as_referenced_keys=None, any_operator=False):
+    """traverse an expression and locate binary criterion pairs."""
+    
+    if consider_as_foreign_keys and consider_as_referenced_keys:
+        raise exceptions.ArgumentError("Can only specify one of 'consider_as_foreign_keys' or 'consider_as_referenced_keys'")
+        
+    def visit_binary(binary):
+        if not any_operator and binary.operator != operators.eq:
+            return
+        if not isinstance(binary.left, sql.ColumnElement) or not isinstance(binary.right, sql.ColumnElement):
+            return
+
+        if consider_as_foreign_keys:
+            if binary.left in consider_as_foreign_keys:
+                pairs.append((binary.right, binary.left))
+            elif binary.right in consider_as_foreign_keys:
+                pairs.append((binary.left, binary.right))
+        elif consider_as_referenced_keys:
+            if binary.left in consider_as_referenced_keys:
+                pairs.append((binary.left, binary.right))
+            elif binary.right in consider_as_referenced_keys:
+                pairs.append((binary.right, binary.left))
+        else:
+            if isinstance(binary.left, schema.Column) and isinstance(binary.right, schema.Column):
+                if binary.left.references(binary.right):
+                    pairs.append((binary.right, binary.left))
+                elif binary.right.references(binary.left):
+                    pairs.append((binary.left, binary.right))
+    pairs = []
+    visitors.traverse(expression, visit_binary=visit_binary)
+    return pairs
+    
 class AliasedRow(object):
     
     def __init__(self, row, map):
         return self.row.keys()
 
 def row_adapter(from_, equivalent_columns=None):
-    """create a row adapter against a selectable."""
+    """create a row adapter callable against a selectable."""
     
     if equivalent_columns is None:
         equivalent_columns = {}

test/orm/collection.py

         collection_class = lambda: Ordered2(lambda v: (v.a, v.b))
         self._test_composite_mapped(collection_class)
 
+# TODO: are these tests redundant vs. the above tests ?
+# remove if so
+class CustomCollectionsTest(ORMTest):
+    def define_tables(self, metadata):
+        global sometable, someothertable
+        sometable = Table('sometable', metadata,
+            Column('col1',Integer, primary_key=True),
+            Column('data', String(30)))
+        someothertable = Table('someothertable', metadata,
+            Column('col1', Integer, primary_key=True),
+            Column('scol1', Integer, ForeignKey(sometable.c.col1)),
+            Column('data', String(20))
+        )
+    def test_basic(self):
+        class MyList(list):
+            pass
+        class Foo(object):
+            pass
+        class Bar(object):
+            pass
+        mapper(Foo, sometable, properties={
+            'bars':relation(Bar, collection_class=MyList)
+        })
+        mapper(Bar, someothertable)
+        f = Foo()
+        assert isinstance(f.bars, MyList)
+        
+    def test_lazyload(self):
+        """test that a 'set' can be used as a collection and can lazyload."""
+        class Foo(object):
+            pass
+        class Bar(object):
+            pass
+        mapper(Foo, sometable, properties={
+            'bars':relation(Bar, collection_class=set)
+        })
+        mapper(Bar, someothertable)
+        f = Foo()
+        f.bars.add(Bar())
+        f.bars.add(Bar())
+        sess = create_session()
+        sess.save(f)
+        sess.flush()
+        sess.clear()
+        f = sess.query(Foo).get(f.col1)
+        assert len(list(f.bars)) == 2
+        f.bars.clear()
+
+    def test_dict(self):
+        """test that a 'dict' can be used as a collection and can lazyload."""
+
+        class Foo(object):
+            pass
+        class Bar(object):
+            pass
+        class AppenderDict(dict):
+            @collection.appender
+            def set(self, item):
+                self[id(item)] = item
+            @collection.remover
+            def remove(self, item):
+                if id(item) in self:
+                    del self[id(item)]
+
+        mapper(Foo, sometable, properties={
+            'bars':relation(Bar, collection_class=AppenderDict)
+        })
+        mapper(Bar, someothertable)
+        f = Foo()
+        f.bars.set(Bar())
+        f.bars.set(Bar())
+        sess = create_session()
+        sess.save(f)
+        sess.flush()
+        sess.clear()
+        f = sess.query(Foo).get(f.col1)
+        assert len(list(f.bars)) == 2
+        f.bars.clear()
+
+    def test_dict_wrapper(self):
+        """test that the supplied 'dict' wrapper can be used as a collection and can lazyload."""
+
+        class Foo(object):
+            pass
+        class Bar(object):
+            def __init__(self, data): self.data = data
+
+        mapper(Foo, sometable, properties={
+            'bars':relation(Bar,
+                collection_class=collections.column_mapped_collection(someothertable.c.data))
+        })
+        mapper(Bar, someothertable)
+
+        f = Foo()
+        col = collections.collection_adapter(f.bars)
+        col.append_with_event(Bar('a'))
+        col.append_with_event(Bar('b'))
+        sess = create_session()
+        sess.save(f)
+        sess.flush()
+        sess.clear()
+        f = sess.query(Foo).get(f.col1)
+        assert len(list(f.bars)) == 2
+
+        existing = set([id(b) for b in f.bars.values()])
+
+        col = collections.collection_adapter(f.bars)
+        col.append_with_event(Bar('b'))
+        f.bars['a'] = Bar('a')
+        sess.flush()
+        sess.clear()
+        f = sess.query(Foo).get(f.col1)
+        assert len(list(f.bars)) == 2
+
+        replaced = set([id(b) for b in f.bars.values()])
+        self.assert_(existing != replaced)
+
+    def test_list(self):
+        class Parent(object):
+            pass
+        class Child(object):
+            pass
+
+        mapper(Parent, sometable, properties={
+            'children':relation(Child, collection_class=list)
+        })
+        mapper(Child, someothertable)
+
+        control = list()
+        p = Parent()
+
+        o = Child()
+        control.append(o)
+        p.children.append(o)
+        assert control == p.children
+        assert control == list(p.children)
+
+        o = [Child(), Child(), Child(), Child()]
+        control.extend(o)
+        p.children.extend(o)
+        assert control == p.children
+        assert control == list(p.children)
+
+        assert control[0] == p.children[0]
+        assert control[-1] == p.children[-1]
+        assert control[1:3] == p.children[1:3]
+
+        del control[1]
+        del p.children[1]
+        assert control == p.children
+        assert control == list(p.children)
+
+        o = [Child()]
+        control[1:3] = o
+        p.children[1:3] = o
+        assert control == p.children
+        assert control == list(p.children)
+
+        o = [Child(), Child(), Child(), Child()]
+        control[1:3] = o
+        p.children[1:3] = o
+        assert control == p.children
+        assert control == list(p.children)
+
+        o = [Child(), Child(), Child(), Child()]
+        control[-1:-2] = o
+        p.children[-1:-2] = o
+        assert control == p.children
+        assert control == list(p.children)
+
+        o = [Child(), Child(), Child(), Child()]
+        control[4:] = o
+        p.children[4:] = o
+        assert control == p.children
+        assert control == list(p.children)
+
+        o = Child()
+        control.insert(0, o)
+        p.children.insert(0, o)
+        assert control == p.children
+        assert control == list(p.children)
+
+        o = Child()
+        control.insert(3, o)
+        p.children.insert(3, o)
+        assert control == p.children
+        assert control == list(p.children)
+
+        o = Child()
+        control.insert(999, o)
+        p.children.insert(999, o)
+        assert control == p.children
+        assert control == list(p.children)
+
+        del control[0:1]
+        del p.children[0:1]
+        assert control == p.children
+        assert control == list(p.children)
+
+        del control[1:1]
+        del p.children[1:1]
+        assert control == p.children
+        assert control == list(p.children)
+
+        del control[1:3]
+        del p.children[1:3]
+        assert control == p.children
+        assert control == list(p.children)
+
+        del control[7:]
+        del p.children[7:]
+        assert control == p.children
+        assert control == list(p.children)
+
+        assert control.pop() == p.children.pop()
+        assert control == p.children
+        assert control == list(p.children)
+
+        assert control.pop(0) == p.children.pop(0)
+        assert control == p.children
+        assert control == list(p.children)
+
+        assert control.pop(2) == p.children.pop(2)
+        assert control == p.children
+        assert control == list(p.children)
+
+        o = Child()
+        control.insert(2, o)
+        p.children.insert(2, o)
+        assert control == p.children
+        assert control == list(p.children)
+
+        control.remove(o)
+        p.children.remove(o)
+        assert control == p.children
+        assert control == list(p.children)
+
+    def test_custom(self):
+        class Parent(object):
+            pass
+        class Child(object):
+            pass
+
+        class MyCollection(object):
+            def __init__(self):
+                self.data = []
+            @collection.appender
+            def append(self, value):
+                self.data.append(value)
+            @collection.remover
+            def remove(self, value):
+                self.data.remove(value)
+            @collection.iterator
+            def __iter__(self):
+                return iter(self.data)
+
+        mapper(Parent, sometable, properties={
+            'children':relation(Child, collection_class=MyCollection)
+        })
+        mapper(Child, someothertable)
+
+        control = list()
+        p1 = Parent()
+
+        o = Child()
+        control.append(o)
+        p1.children.append(o)
+        assert control == list(p1.children)
+
+        o = Child()
+        control.append(o)
+        p1.children.append(o)
+        assert control == list(p1.children)
+
+        o = Child()
+        control.append(o)
+        p1.children.append(o)
+        assert control == list(p1.children)
+
+        sess = create_session()
+        sess.save(p1)
+        sess.flush()
+        sess.clear()
+
+        p2 = sess.query(Parent).get(p1.col1)
+        o = list(p2.children)
+        assert len(o) == 3
+
 if __name__ == "__main__":
     testenv.main()

test/orm/inheritance/polymorph2.py

             pass
         class Manager(Person):
             pass
-
-        mapper(Person, people, properties={
-            'manager':relation(Manager, primaryjoin=people.c.manager_id==managers.c.person_id, uselist=False)
-        })
-        mapper(Manager, managers, inherits=Person, inherit_condition=people.c.person_id==managers.c.person_id)
-
-        self.assertRaisesMessage(exceptions.ArgumentError, 
-            r"Can't determine relation direction for relationship 'Person\.manager \(Manager\)' - foreign key columns are present in both the parent and the child's mapped tables\.  Specify 'foreign_keys' argument\.",
-            compile_mappers
-        )
-        clear_mappers()
-
+        
+        # note that up until recently (0.4.4), we had to specify "foreign_keys" here
+        # for this primary join.  
         mapper(Person, people, properties={
             'manager':relation(Manager, primaryjoin=(people.c.manager_id ==
                                                      managers.c.person_id),
-                               foreign_keys=[people.c.manager_id],
                                uselist=False, post_update=True)
         })
         mapper(Manager, managers, inherits=Person,
                inherit_condition=people.c.person_id==managers.c.person_id)
-
+        
+        self.assertEquals(class_mapper(Person).get_property('manager').foreign_keys, set([people.c.manager_id]))
+        
         session = create_session()
         p = Person(name='some person')
         m = Manager(name='some manager')

test/orm/relationships.py

 from sqlalchemy.orm import collections
 from sqlalchemy.orm.collections import collection
 from testlib import *
+from testlib import fixtures
 
 class RelationTest(TestBase):
     """An extended topological sort test
 
         assert t3.count().scalar() == 1
 
-# TODO: move these tests to either attributes.py test or its own module
-class CustomCollectionsTest(ORMTest):
-    def define_tables(self, metadata):
-        global sometable, someothertable
-        sometable = Table('sometable', metadata,
-            Column('col1',Integer, primary_key=True),
-            Column('data', String(30)))
-        someothertable = Table('someothertable', metadata,
-            Column('col1', Integer, primary_key=True),
-            Column('scol1', Integer, ForeignKey(sometable.c.col1)),
-            Column('data', String(20))
-        )
-    def testbasic(self):
-        class MyList(list):
-            pass
-        class Foo(object):
-            pass
-        class Bar(object):
-            pass
-        mapper(Foo, sometable, properties={
-            'bars':relation(Bar, collection_class=MyList)
-        })
-        mapper(Bar, someothertable)
-        f = Foo()
-        assert isinstance(f.bars, MyList)
-    def testlazyload(self):
-        """test that a 'set' can be used as a collection and can lazyload."""
-        class Foo(object):
-            pass
-        class Bar(object):
-            pass
-        mapper(Foo, sometable, properties={
-            'bars':relation(Bar, collection_class=set)
-        })
-        mapper(Bar, someothertable)
-        f = Foo()
-        f.bars.add(Bar())
-        f.bars.add(Bar())
-        sess = create_session()
-        sess.save(f)
-        sess.flush()
-        sess.clear()
-        f = sess.query(Foo).get(f.col1)
-        assert len(list(f.bars)) == 2
-        f.bars.clear()
-
-    def testdict(self):
-        """test that a 'dict' can be used as a collection and can lazyload."""
-
-        class Foo(object):
-            pass
-        class Bar(object):
-            pass
-        class AppenderDict(dict):
-            @collection.appender
-            def set(self, item):
-                self[id(item)] = item
-            @collection.remover
-            def remove(self, item):
-                if id(item) in self:
-                    del self[id(item)]
-
-        mapper(Foo, sometable, properties={
-            'bars':relation(Bar, collection_class=AppenderDict)
-        })
-        mapper(Bar, someothertable)
-        f = Foo()
-        f.bars.set(Bar())
-        f.bars.set(Bar())
-        sess = create_session()
-        sess.save(f)
-        sess.flush()
-        sess.clear()
-        f = sess.query(Foo).get(f.col1)
-        assert len(list(f.bars)) == 2
-        f.bars.clear()
-
-    def testdictwrapper(self):
-        """test that the supplied 'dict' wrapper can be used as a collection and can lazyload."""
-
-        class Foo(object):
-            pass
-        class Bar(object):
-            def __init__(self, data): self.data = data
-
-        mapper(Foo, sometable, properties={
-            'bars':relation(Bar,
-                collection_class=collections.column_mapped_collection(someothertable.c.data))
-        })
-        mapper(Bar, someothertable)
-
-        f = Foo()
-        col = collections.collection_adapter(f.bars)
-        col.append_with_event(Bar('a'))
-        col.append_with_event(Bar('b'))
-        sess = create_session()
-        sess.save(f)
-        sess.flush()
-        sess.clear()
-        f = sess.query(Foo).get(f.col1)
-        assert len(list(f.bars)) == 2
-
-        existing = set([id(b) for b in f.bars.values()])
-
-        col = collections.collection_adapter(f.bars)
-        col.append_with_event(Bar('b'))
-        f.bars['a'] = Bar('a')
-        sess.flush()
-        sess.clear()
-        f = sess.query(Foo).get(f.col1)
-        assert len(list(f.bars)) == 2
-
-        replaced = set([id(b) for b in f.bars.values()])
-        self.assert_(existing != replaced)
-
-    def testlist(self):
-        class Parent(object):
-            pass
-        class Child(object):
-            pass
-
-        mapper(Parent, sometable, properties={
-            'children':relation(Child, collection_class=list)
-        })
-        mapper(Child, someothertable)
-
-        control = list()
-        p = Parent()
-
-        o = Child()
-        control.append(o)
-        p.children.append(o)
-        assert control == p.children
-        assert control == list(p.children)
-
-        o = [Child(), Child(), Child(), Child()]
-        control.extend(o)
-        p.children.extend(o)
-        assert control == p.children
-        assert control == list(p.children)
-
-        assert control[0] == p.children[0]
-        assert control[-1] == p.children[-1]
-        assert control[1:3] == p.children[1:3]
-
-        del control[1]
-        del p.children[1]
-        assert control == p.children
-        assert control == list(p.children)
-
-        o = [Child()]
-        control[1:3] = o
-        p.children[1:3] = o
-        assert control == p.children
-        assert control == list(p.children)
-
-        o = [Child(), Child(), Child(), Child()]
-        control[1:3] = o
-        p.children[1:3] = o
-        assert control == p.children
-        assert control == list(p.children)
-
-        o = [Child(), Child(), Child(), Child()]
-        control[-1:-2] = o
-        p.children[-1:-2] = o
-        assert control == p.children
-        assert control == list(p.children)
-
-        o = [Child(), Child(), Child(), Child()]
-        control[4:] = o
-        p.children[4:] = o
-        assert control == p.children
-        assert control == list(p.children)
-
-        o = Child()
-        control.insert(0, o)
-        p.children.insert(0, o)
-        assert control == p.children
-        assert control == list(p.children)
-
-        o = Child()
-        control.insert(3, o)
-        p.children.insert(3, o)
-        assert control == p.children
-        assert control == list(p.children)
-
-        o = Child()
-        control.insert(999, o)
-        p.children.insert(999, o)
-        assert control == p.children
-        assert control == list(p.children)
-
-        del control[0:1]
-        del p.children[0:1]
-        assert control == p.children
-        assert control == list(p.children)
-
-        del control[1:1]
-        del p.children[1:1]
-        assert control == p.children
-        assert control == list(p.children)
-
-        del control[1:3]
-        del p.children[1:3]
-        assert control == p.children
-        assert control == list(p.children)
-
-        del control[7:]
-        del p.children[7:]
-        assert control == p.children
-        assert control == list(p.children)
-
-        assert control.pop() == p.children.pop()
-        assert control == p.children
-        assert control == list(p.children)
-
-        assert control.pop(0) == p.children.pop(0)
-        assert control == p.children
-        assert control == list(p.children)
-
-        assert control.pop(2) == p.children.pop(2)
-        assert control == p.children
-        assert control == list(p.children)
-
-        o = Child()
-        control.insert(2, o)
-        p.children.insert(2, o)
-        assert control == p.children
-        assert control == list(p.children)
-
-        control.remove(o)
-        p.children.remove(o)
-        assert control == p.children
-        assert control == list(p.children)
-
-    def testobj(self):
-        class Parent(object):
-            pass
-        class Child(object):
-            pass
-
-        class MyCollection(object):
-            def __init__(self):
-                self.data = []
-            @collection.appender
-            def append(self, value):
-                self.data.append(value)
-            @collection.remover
-            def remove(self, value):
-                self.data.remove(value)
-            @collection.iterator
-            def __iter__(self):
-                return iter(self.data)
-
-        mapper(Parent, sometable, properties={
-            'children':relation(Child, collection_class=MyCollection)
-        })
-        mapper(Child, someothertable)
-
-        control = list()
-        p1 = Parent()
-
-        o = Child()
-        control.append(o)
-        p1.children.append(o)
-        assert control == list(p1.children)
-
-        o = Child()
-        control.append(o)
-        p1.children.append(o)
-        assert control == list(p1.children)
-
-        o = Child()
-        control.append(o)
-        p1.children.append(o)
-        assert control == list(p1.children)
-
-        sess = create_session()
-        sess.save(p1)
-        sess.flush()
-        sess.clear()
-
-        p2 = sess.query(Parent).get(p1.col1)
-        o = list(p2.children)
-        assert len(o) == 3
+        
+    
 
 class ViewOnlyTest(ORMTest):
     """test a view_only mapping where a third table is pulled into the primary join condition,
         assert set([x.t2id for x in c1.t2s]) == set([c2a.t2id, c2b.t2id])
         assert set([x.t2id for x in c1.t2_view]) == set([c2b.t2id])
 
+class ViewOnlyTest3(ORMTest):
+    def define_tables(self, metadata):
+        global foos, bars
+        foos = Table('foos', metadata, Column('id', Integer, primary_key=True))
+        bars = Table('bars', metadata, Column('id', Integer, primary_key=True), Column('fid', Integer))
+
+    def test_viewonly_join(self):
+        class Foo(fixtures.Base):
+            pass
+        class Bar(fixtures.Base):
+            pass
+
+        mapper(Foo, foos, properties={
+            'bars':relation(Bar, primaryjoin=foos.c.id>bars.c.fid, foreign_keys=[bars.c.fid], viewonly=True)
+        })
+
+        mapper(Bar, bars)
+
+        sess = create_session()
+        sess.save(Foo(id=4))
+        sess.save(Foo(id=9))
+        sess.save(Bar(id=1, fid=2))
+        sess.save(Bar(id=2, fid=3))
+        sess.save(Bar(id=3, fid=6))
+        sess.save(Bar(id=4, fid=7))
+        sess.flush()
+
+        sess = create_session()
+        self.assertEquals(sess.query(Foo).filter_by(id=4).one(), Foo(id=4, bars=[Bar(fid=2), Bar(fid=3)]))
+        self.assertEquals(sess.query(Foo).filter_by(id=9).one(), Foo(id=9, bars=[Bar(fid=2), Bar(fid=3), Bar(fid=6), Bar(fid=7)]))
+
+class InvalidRelationEscalationTest(ORMTest):
+    def define_tables(self, metadata):
+        global foos, bars, Foo, Bar
+        foos = Table('foos', metadata, Column('id', Integer, primary_key=True), Column('fid', Integer))
+        bars = Table('bars', metadata, Column('id', Integer, primary_key=True), Column('fid', Integer))
+        class Foo(object):
+            pass
+        class Bar(object):
+            pass
+            
+    def test_no_join(self):
+        mapper(Foo, foos, properties={
+            'bars':relation(Bar)
+        })
+
+        mapper(Bar, bars)
+        self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers)
+
+    def test_no_join_self_ref(self):
+        mapper(Foo, foos, properties={
+            'foos':relation(Foo)
+        })
+
+        mapper(Bar, bars)
+        self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers)
+        
+    def test_no_equated(self):
+        mapper(Foo, foos, properties={
+            'bars':relation(Bar, primaryjoin=foos.c.id>bars.c.fid)
+        })
+
+        mapper(Bar, bars)
+        self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+
+    def test_no_equated_fks(self):
+        mapper(Foo, foos, properties={
+            'bars':relation(Bar, primaryjoin=foos.c.id>bars.c.fid, foreign_keys=bars.c.fid)
+        })
+
+        mapper(Bar, bars)
+        self.assertRaisesMessage(exceptions.ArgumentError, "Could not locate any equated column pairs for primaryjoin condition", compile_mappers)
+
+    def test_no_equated_self_ref(self):
+        mapper(Foo, foos, properties={
+            'foos':relation(Foo, primaryjoin=foos.c.id>foos.c.fid)
+        })
+
+        mapper(Bar, bars)
+        self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+
+    def test_no_equated_self_ref(self):
+        mapper(Foo, foos, properties={
+            'foos':relation(Foo, primaryjoin=foos.c.id>foos.c.fid, foreign_keys=[foos.c.fid])
+        })
+
+        mapper(Bar, bars)
+        self.assertRaisesMessage(exceptions.ArgumentError, "Could not locate any equated column pairs for primaryjoin condition", compile_mappers)
+
+    def test_no_equated_viewonly(self):
+        mapper(Foo, foos, properties={
+            'bars':relation(Bar, primaryjoin=foos.c.id>bars.c.fid, viewonly=True)
+        })
+
+        mapper(Bar, bars)
+        self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+
+    def test_no_equated_self_ref_viewonly(self):
+        mapper(Foo, foos, properties={
+            'foos':relation(Foo, primaryjoin=foos.c.id>foos.c.fid, viewonly=True)
+        })
+
+        mapper(Bar, bars)
+
+        self.assertRaisesMessage(exceptions.ArgumentError, "Specify the foreign_keys argument to indicate which columns on the relation are foreign.", compile_mappers)
+
+    def test_no_equated_self_ref_viewonly_fks(self):
+        mapper(Foo, foos, properties={
+            'foos':relation(Foo, primaryjoin=foos.c.id>foos.c.fid, viewonly=True, foreign_keys=[foos.c.fid])
+        })
+        compile_mappers()
+        self.assertEquals(Foo.foos.property.equated_pairs, [(foos.c.id, foos.c.fid)])
+
+    def test_equated(self):
+        mapper(Foo, foos, properties={
+            'bars':relation(Bar, primaryjoin=foos.c.id==bars.c.fid)
+        })
+        mapper(Bar, bars)
+        self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+    
+    def test_equated_self_ref(self):
+        mapper(Foo, foos, properties={
+            'foos':relation(Foo, primaryjoin=foos.c.id==foos.c.fid)
+        })
+
+        self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+
+    def test_equated_self_ref_wrong_fks(self):
+        mapper(Foo, foos, properties={
+            'foos':relation(Foo, primaryjoin=foos.c.id==foos.c.fid, foreign_keys=[bars.c.id])
+        })
+
+        self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+
+class InvalidRelationEscalationTestM2M(ORMTest):
+    def define_tables(self, metadata):
+        global foos, bars, Foo, Bar, foobars
+        foos = Table('foos', metadata, Column('id', Integer, primary_key=True))
+        foobars = Table('foobars', metadata, Column('fid', Integer), Column('bid', Integer))
+        bars = Table('bars', metadata, Column('id', Integer, primary_key=True))
+        class Foo(object):
+            pass
+        class Bar(object):
+            pass
+
+    def test_no_join(self):
+        mapper(Foo, foos, properties={
+            'bars':relation(Bar, secondary=foobars)
+        })
+
+        mapper(Bar, bars)
+        self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers)
+
+    def test_no_secondaryjoin(self):
+        mapper(Foo, foos, properties={
+            'bars':relation(Bar, secondary=foobars, primaryjoin=foos.c.id>foobars.c.fid)
+        })
+
+        mapper(Bar, bars)
+        self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine join condition between parent/child tables on relation", compile_mappers)
+
+    def test_bad_primaryjoin(self):
+        mapper(Foo, foos, properties={
+            'bars':relation(Bar, secondary=foobars, primaryjoin=foos.c.id>foobars.c.fid, secondaryjoin=foobars.c.bid<=bars.c.id)
+        })
+
+        mapper(Bar, bars)
+        self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for primaryjoin condition", compile_mappers)
+
+    def test_bad_secondaryjoin(self):
+        mapper(Foo, foos, properties={
+            'bars':relation(Bar, secondary=foobars, primaryjoin=foos.c.id==foobars.c.fid, secondaryjoin=foobars.c.bid<=bars.c.id, foreign_keys=[foobars.c.fid])
+        })
+
+        mapper(Bar, bars)
+        self.assertRaisesMessage(exceptions.ArgumentError, "Could not determine relation direction for secondaryjoin condition", compile_mappers)
+
+    def test_no_equated_secondaryjoin(self):
+        mapper(Foo, foos, properties={
+            'bars':relation(Bar, secondary=foobars, primaryjoin=foos.c.id==foobars.c.fid, secondaryjoin=foobars.c.bid<=bars.c.id, foreign_keys=[foobars.c.fid, foobars.c.bid])
+        })
+
+        mapper(Bar, bars)
+        self.assertRaisesMessage(exceptions.ArgumentError, "Could not locate any equated column pairs for secondaryjoin condition", compile_mappers)
+
 
 if __name__ == "__main__":
     testenv.main()
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.