Commits

Mike Bayer committed 35e8e02

factored out "syncrule" logic to a separate package, so mapper will be able to make use of it as well as properties. also clarifies the "synchronization" idea

Comments (0)

Files changed (4)

lib/sqlalchemy/mapping/objectstore.py

 
     def register_callable(self, obj, key, func, uselist, **kwargs):
         self.attributes.set_callable(obj, key, func, uselist, **kwargs)
-        
+    
     def register_clean(self, obj):
         try:
             del self.dirty[obj]

lib/sqlalchemy/mapping/properties.py

 import sqlalchemy.engine as engine
 import sqlalchemy.util as util
 import sqlalchemy.attributes as attributes
+import sync
 import mapper
 import objectstore
 from sqlalchemy.exceptions import *
             if self.backref is not None:
                 # try to set a LazyLoader on our mapper referencing the parent mapper
                 if not self.mapper.props.has_key(self.backref):
-                    self.mapper.add_property(self.backref, LazyLoader(self.parent, self.secondary, self.primaryjoin, self.secondaryjoin, backref=self.key, is_backref=True));
+                    if self.secondaryjoin is not None:
+                        # if setting up a backref to a many-to-many, reverse the order
+                        # of the "primary" and "secondary" joins
+                        pj = self.secondaryjoin
+                        sj = self.primaryjoin
+                    else:
+                        pj = self.primaryjoin
+                        sj = None
+                    self.mapper.add_property(self.backref, LazyLoader(self.parent, self.secondary, pj, sj, backref=self.key, is_backref=True));
                 else:
                     # else set one of us as the "backreference"
                     if not self.mapper.props[self.backref].is_backref:
         The list of rules is used within commits by the _synchronize() method when dependent 
         objects are processed."""
 
-        SyncRule = PropertyLoader.SyncRule
-
         parent_tables = util.HashSet(self.parent.tables + [self.parent.primarytable])
         target_tables = util.HashSet(self.mapper.tables + [self.mapper.primarytable])
 
-        def check_for_table(binary, l):
-            for col in [binary.left, binary.right]:
-                if col.table in l:
-                    return col
-            else:
-                return None
-        
-        def compile(binary):
-            """assembles a SyncRule given a single binary condition"""
-            if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
-                return
-
-            if binary.left.table == binary.right.table:
-                # self-cyclical relation
-                if binary.left.primary_key:
-                    source = binary.left
-                    dest = binary.right
-                elif binary.right.primary_key:
-                    source = binary.right
-                    dest = binary.left
-                else:
-                    raise ArgumentError("Cant determine direction for relationship %s = %s" % (binary.left.fullname, binary.right.fullname))
-                if self.direction == PropertyLoader.ONETOMANY:
-                    self.syncrules.append(SyncRule(self.parent, source, dest, dest_mapper=self.mapper))
-                elif self.direction == PropertyLoader.MANYTOONE:
-                    self.syncrules.append(SyncRule(self.mapper, source, dest, dest_mapper=self.parent))
-                else:
-                    raise AssertionError("assert failed")
-            else:
-                pt = check_for_table(binary, parent_tables)
-                tt = check_for_table(binary, target_tables)
-                st = check_for_table(binary, [self.secondary])
-                #print "parenttable", [t.name for t in parent_tables]
-                #print "ttable", [t.name for t in target_tables]
-                #print "OK", str(binary), pt, tt, st
-                if pt and tt:
-                    if self.direction == PropertyLoader.ONETOMANY:
-                        self.syncrules.append(SyncRule(self.parent, pt, tt, dest_mapper=self.mapper))
-                    elif self.direction == PropertyLoader.MANYTOONE:
-                        self.syncrules.append(SyncRule(self.mapper, tt, pt, dest_mapper=self.parent))
-                    else:
-                        if visiting is self.primaryjoin:
-                            self.syncrules.append(SyncRule(self.parent, pt, st, direction=PropertyLoader.ONETOMANY))
-                        else:
-                            self.syncrules.append(SyncRule(self.mapper, tt, st, direction=PropertyLoader.MANYTOONE))
-                elif pt and st:
-                    self.syncrules.append(SyncRule(self.parent, pt, st, direction=PropertyLoader.ONETOMANY))
-                elif tt and st:
-                    self.syncrules.append(SyncRule(self.mapper, tt, st, direction=PropertyLoader.MANYTOONE))
-
-        self.syncrules = []
-        processor = BinaryVisitor(compile)
-        visiting = self.primaryjoin
-        self.primaryjoin.accept_visitor(processor)
-        if self.secondaryjoin is not None:
-            visiting = self.secondaryjoin
-            self.secondaryjoin.accept_visitor(processor)
-        if len(self.syncrules) == 0:
-            raise ArgumentError("No syncrules generated for join criterion " + str(self.primaryjoin))
+        self.syncrules = sync.ClauseSynchronizer(self.parent, self.mapper, self.direction)
+        if self.direction == PropertyLoader.MANYTOMANY:
+            #print "COMPILING p/c", self.parent, self.mapper
+            self.syncrules.compile(self.primaryjoin, parent_tables, [self.secondary], False)
+            self.syncrules.compile(self.secondaryjoin, target_tables, [self.secondary], True)
+        else:
+            self.syncrules.compile(self.primaryjoin, parent_tables, target_tables)
 
     def _synchronize(self, obj, child, associationrow, clearkeys):
         """called during a commit to execute the full list of syncrules on the 
         if dest is None:
             return
 
-        for rule in self.syncrules:
-            rule.execute(source, dest, obj, child, clearkeys)
-
-    class SyncRule(object):
-        """An instruction indicating how to populate the objects on each side of a relationship.  
-        i.e. if table1 column A is joined against
-        table2 column B, and we are a one-to-many from table1 to table2, a syncrule would say 
-        'take the A attribute from object1 and assign it to the B attribute on object2'.  
-        
-        A rule contains the source mapper, the source column, destination column, 
-        destination mapper in the case of a one/many relationship, and
-        the integer direction of this mapper relative to the association in the case
-        of a many to many relationship.
-        """
-        def __init__(self, source_mapper, source_column, dest_column, dest_mapper=None, direction=None):
-            self.source_mapper = source_mapper
-            self.source_column = source_column
-            self.direction = direction
-            self.dest_mapper = dest_mapper
-            self.dest_column = dest_column
-            #print "SyncRule", source_mapper, source_column, dest_column, dest_mapper, direction
-
-        def execute(self, source, dest, obj, child, clearkeys):
-            if self.direction is not None:
-                self.exec_many2many(dest, obj, child, clearkeys)
-            else:
-                self.exec_one2many(source, dest, clearkeys)
-
-        def exec_many2many(self, destination, obj, child, clearkeys):
-            if self.direction == PropertyLoader.ONETOMANY:
-                source = obj
-            elif self.direction == PropertyLoader.MANYTOONE:
-                source = child
-            if clearkeys:
-                value = None
-            else:
-                value = self.source_mapper._getattrbycolumn(source, self.source_column)
-            destination[self.dest_column.key] = value
-            
-        def exec_one2many(self, source, destination, clearkeys):
-            if clearkeys or source is None:
-                value = None
-            else:
-                value = self.source_mapper._getattrbycolumn(source, self.source_column)
-            #print "SYNC VALUE", value, "TO", destination
-            self.dest_mapper._setattrbycolumn(destination, self.dest_column, value)
-                
+        self.syncrules.execute(source, dest, obj, child, clearkeys)
 
 class LazyLoader(PropertyLoader):
     def do_init_subclass(self, key, parent):

lib/sqlalchemy/mapping/sync.py

+import sqlalchemy.sql as sql
+import sqlalchemy.schema as schema
+from sqlalchemy.exceptions import *
+import properties
+
+"""contains the ClauseSynchronizer class which is used to map attributes between two objects
+in a manner corresponding to a SQL clause that compares column values."""
+
+ONETOMANY = 0
+MANYTOONE = 1
+MANYTOMANY = 2
+
+class ClauseSynchronizer(object):
+    """Given a SQL clause, usually a series of one or more binary 
+    expressions between columns, and a set of 'source' and 'destination' mappers, compiles a set of SyncRules
+    corresponding to that information.  The ClauseSynchronizer can then be executed given a set of parent/child 
+    objects or destination dictionary, which will iterate through each of its SyncRules and execute them.
+    Each SyncRule will copy the value of a single attribute from the parent
+    to the child, corresponding to the pair of columns in a particular binary expression, using the source and
+    destination mappers to map those two columns to object attributes within parent and child."""
+    def __init__(self, parent_mapper, child_mapper, direction):
+        self.parent_mapper = parent_mapper
+        self.child_mapper = child_mapper
+        self.direction = direction
+        self.syncrules = []
+
+    def compile(self, sqlclause, source_tables, target_tables, issecondary=None):
+        def check_for_table(binary, l):
+            for col in [binary.left, binary.right]:
+                if col.table in l:
+                    return col
+            else:
+                return None
+        
+        def compile_binary(binary):
+            """assembles a SyncRule given a single binary condition"""
+            if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
+                return
+
+            if binary.left.table == binary.right.table:
+                # self-cyclical relation
+                if binary.left.primary_key:
+                    source = binary.left
+                    dest = binary.right
+                elif binary.right.primary_key:
+                    source = binary.right
+                    dest = binary.left
+                else:
+                    raise ArgumentError("Cant determine direction for relationship %s = %s" % (binary.left.fullname, binary.right.fullname))
+                if self.direction == ONETOMANY:
+                    self.syncrules.append(SyncRule(self.parent_mapper, source, dest, dest_mapper=self.child_mapper))
+                elif self.direction == MANYTOONE:
+                    self.syncrules.append(SyncRule(self.child_mapper, source, dest, dest_mapper=self.parent_mapper))
+                else:
+                    raise AssertionError("assert failed")
+            else:
+                pt = check_for_table(binary, source_tables)
+                tt = check_for_table(binary, target_tables)
+                #print "OK", binary, [t.name for t in source_tables], [t.name for t in target_tables]
+                if pt and tt:
+                    if self.direction == ONETOMANY:
+                        self.syncrules.append(SyncRule(self.parent_mapper, pt, tt, dest_mapper=self.child_mapper))
+                    elif self.direction == MANYTOONE:
+                        self.syncrules.append(SyncRule(self.child_mapper, tt, pt, dest_mapper=self.parent_mapper))
+                    else:
+                        if not issecondary:
+                            self.syncrules.append(SyncRule(self.parent_mapper, pt, tt, dest_mapper=self.child_mapper, issecondary=issecondary))
+                        else:
+                            self.syncrules.append(SyncRule(self.child_mapper, pt, tt, dest_mapper=self.parent_mapper, issecondary=issecondary))
+                            
+        rules_added = len(self.syncrules)
+        processor = BinaryVisitor(compile_binary)
+        sqlclause.accept_visitor(processor)
+        if len(self.syncrules) == rules_added:
+            raise ArgumentError("No syncrules generated for join criterion " + str(sqlclause))
+        
+    def execute(self, source, dest, obj, child, clearkeys):
+        for rule in self.syncrules:
+            rule.execute(source, dest, obj, child, clearkeys)
+        
+class SyncRule(object):
+    """An instruction indicating how to populate the objects on each side of a relationship.  
+    i.e. if table1 column A is joined against
+    table2 column B, and we are a one-to-many from table1 to table2, a syncrule would say 
+    'take the A attribute from object1 and assign it to the B attribute on object2'.  
+    
+    A rule contains the source mapper, the source column, destination column, 
+    destination mapper in the case of a one/many relationship, and
+    the integer direction of this mapper relative to the association in the case
+    of a many to many relationship.
+    """
+    def __init__(self, source_mapper, source_column, dest_column, dest_mapper=None, issecondary=None):
+        self.source_mapper = source_mapper
+        self.source_column = source_column
+        self.issecondary = issecondary
+        self.dest_mapper = dest_mapper
+        self.dest_column = dest_column
+        #print "SyncRule", source_mapper, source_column, dest_column, dest_mapper, direction
+
+    def execute(self, source, dest, obj, child, clearkeys):
+        if source is None:
+            if self.issecondary is False:
+                source = obj
+            elif self.issecondary is True:
+                source = child
+        if clearkeys or source is None:
+            value = None
+        else:
+            value = self.source_mapper._getattrbycolumn(source, self.source_column)
+        if isinstance(dest, dict):
+            dest[self.dest_column.key] = value
+        else:
+            #print "SYNC VALUE", value, "TO", dest
+            self.dest_mapper._setattrbycolumn(dest, self.dest_column, value)
+            
+class BinaryVisitor(sql.ClauseVisitor):
+    def __init__(self, func):
+        self.func = func
+    def visit_binary(self, binary):
+        self.func(binary)
+

test/inheritance.py

             objectstore.commit()
 
 class InheritTest2(testbase.AssertMixin):
-	def setUpAll(self):
-		engine = testbase.db
-		global foo, bar, foo_bar
-		foo = Table('foo', engine,
-		    Column('id', Integer, primary_key=True),
-			Column('data', String(20)),
-			).create()
+    def setUpAll(self):
+        engine = testbase.db
+        global foo, bar, foo_bar
+        foo = Table('foo', engine,
+            Column('id', Integer, primary_key=True),
+            Column('data', String(20)),
+            ).create()
 
-		bar = Table('bar', engine,
-		    Column('bid', Integer, ForeignKey('foo.id'), primary_key=True),
-		    #Column('fid', Integer, ForeignKey('foo.id'), )
-			).create()
+        bar = Table('bar', engine,
+            Column('bid', Integer, ForeignKey('foo.id'), primary_key=True),
+            #Column('fid', Integer, ForeignKey('foo.id'), )
+            ).create()
 
-		foo_bar = Table('foo_bar', engine,
-		    Column('foo_id', Integer, ForeignKey('foo.id')),
-		    Column('bar_id', Integer, ForeignKey('bar.bid'))).create()
+        foo_bar = Table('foo_bar', engine,
+            Column('foo_id', Integer, ForeignKey('foo.id')),
+            Column('bar_id', Integer, ForeignKey('bar.bid'))).create()
 
-	def tearDownAll(self):
-		foo_bar.drop()
-		bar.drop()
-		foo.drop()
+    def tearDownAll(self):
+        foo_bar.drop()
+        bar.drop()
+        foo.drop()
 
-	def testbasic(self):
-		class Foo(object): 
-			def __init__(self, data=None):
-				self.data = data
-			def __str__(self):
-				return "Foo(%s)" % self.data
-			def __repr__(self):
-				return str(self)
+    def testbasic(self):
+        class Foo(object): 
+            def __init__(self, data=None):
+                self.data = data
+            def __str__(self):
+                return "Foo(%s)" % self.data
+            def __repr__(self):
+                return str(self)
 
-		Foo.mapper = mapper(Foo, foo)
-		class Bar(Foo):
-			def __str__(self):
-				return "Bar(%s)" % self.data
+        Foo.mapper = mapper(Foo, foo)
+        class Bar(Foo):
+            def __str__(self):
+                return "Bar(%s)" % self.data
 
-		Bar.mapper = mapper(Bar, bar, inherits=Foo.mapper, properties = {
-				# TODO: use syncrules for this
-				'id':[bar.c.bid, foo.c.id]
-			})
+        Bar.mapper = mapper(Bar, bar, inherits=Foo.mapper, properties = {
+                # TODO: use syncrules for this
+                'id':[bar.c.bid, foo.c.id]
+            })
 
-		Bar.mapper.add_property('foos', relation(Foo.mapper, foo_bar, primaryjoin=bar.c.bid==foo_bar.c.bar_id, secondaryjoin=foo_bar.c.foo_id==foo.c.id, lazy=False))
-		#Bar.mapper.add_property('foos', relation(Foo.mapper, foo_bar, lazy=False))
+        Bar.mapper.add_property('foos', relation(Foo.mapper, foo_bar, primaryjoin=bar.c.bid==foo_bar.c.bar_id, secondaryjoin=foo_bar.c.foo_id==foo.c.id, lazy=False))
+        #Bar.mapper.add_property('foos', relation(Foo.mapper, foo_bar, lazy=False))
 
+        b = Bar('barfoo')
+        objectstore.commit()
 
-		b = Bar('barfoo')
-		objectstore.commit()
+        b.foos.append(Foo('subfoo1'))
+        b.foos.append(Foo('subfoo2'))
 
+        objectstore.commit()
+        objectstore.clear()
 
-		b.foos.append(Foo('subfoo1'))
-		b.foos.append(Foo('subfoo2'))
-
-		objectstore.commit()
-		objectstore.clear()
-
-		l =b.mapper.select()
-		print l[0]
-		print l[0].foos
+        l =b.mapper.select()
+        print l[0]
+        print l[0].foos
+        self.assert_result(l, Bar,
+#            {'id':1, 'data':'barfoo', 'bid':1, 'foos':(Foo, [{'id':2,'data':'subfoo1'}, {'id':3,'data':'subfoo2'}])},
+            {'id':1, 'data':'barfoo', 'foos':(Foo, [{'id':2,'data':'subfoo1'}, {'id':3,'data':'subfoo2'}])},
+            )
 
 
 if __name__ == "__main__":