Commits

Mike Bayer committed 447fe7f

Comments (0)

Files changed (1)

lib/sqlalchemy/mapper.py

 import sqlalchemy.objectstore as objectstore
 import random, copy, types
 
-__ALL__ = ['eagermapper', 'eagerloader', 'lazymapper', 'lazyloader', 'eagerload', 'lazyload', 'assignmapper', 
-        'mapper', 'lazyloader', 'lazymapper', 'clear_mappers', 'objectstore', 'sql', 'extension', 'MapperExtension']
+__ALL__ = ['relation', 'eagerload', 'lazyload', 'noload', 'assignmapper', 
+        'mapper', 'clear_mappers', 'objectstore', 'sql', 'extension', 'class_mapper', 'object_mapper', 'MapperExtension']
 
 def relation(*args, **params):
     """provides a relationship of a primary Mapper to a secondary Mapper, which corresponds
     else:
         return _relation_mapper(*args, **params)
 
-def _relation_loader(mapper, secondary = None, primaryjoin = None, secondaryjoin = None, lazy = True, **kwargs):
+def _relation_loader(mapper, secondary=None, primaryjoin=None, secondaryjoin=None, lazy=True, **kwargs):
     if lazy:
         return LazyLoader(mapper, secondary, primaryjoin, secondaryjoin, **kwargs)
     elif lazy is None:
         return PropertyLoader(mapper, secondary, primaryjoin, secondaryjoin, **kwargs)
     else:
         return EagerLoader(mapper, secondary, primaryjoin, secondaryjoin, **kwargs)
-    
-def _relation_mapper(class_, table=None, secondary=None, primaryjoin=None, secondaryjoin=None, **kwargs):
-    return _relation_loader(mapper(class_, table, **kwargs), secondary, primaryjoin, secondaryjoin, **kwargs)
+
+def _relation_mapper(class_, table=None, secondary=None, 
+                    primaryjoin=None, secondaryjoin=None, 
+                    foreignkey=None, uselist=None, private=False, live=False, association=None, lazy=True, **kwargs):
+
+    return _relation_loader(mapper(class_, table, **kwargs), secondary, primaryjoin, secondaryjoin, 
+                    foreignkey=foreignkey, uselist=uselist, private=private, live=live, association=association, lazy=lazy)
+
+#def _relation_mapper(class_, table=None, secondary=None, 
+#                    primaryjoin=None, secondaryjoin=None, foreignkey=None, 
+#                    uselist=None, private=False, live=False, association=None, **kwargs):
+#    return _relation_loader(mapper(class_, table, **kwargs), secondary, primaryjoin=primaryjoin, secondaryjoin=secondaryjoin, foreignkey=foreignkey, uselist=uselist, private=private, live=live, association=association)
 
 class assignmapper(object):
     """provides a property object that will instantiate a Mapper for a given class the first
     def instance_key(self, instance):
         return self.identity_key(*[self._getattrbycolumn(instance, column) for column in self.primary_keys[self.table]])
 
-#    def _primary_key_ident(self, obj):
-#        """returns an identity of an object based on its primary keys, across all tables 
-#        represented by this mapper."""
-#        res = []
-#        for table in self.tables:
-#            for k in self.primary_keys[table]:
-#                res.append(self._getattrbycolumn(obj, k))
-#        return tuple(res)
-
     def compile(self, whereclause = None, **options):
         """works like select, except returns the SQL statement object without 
         compiling or executing it"""
         """called by a UnitOfWork object to save objects, which involves either an INSERT or
         an UPDATE statement for each table used by this mapper, for each element of the
         list."""
-                
+          
         for table in self.tables:
             # loop thru tables in the outer loop, objects on the inner loop.
             # this is important for an object represented across two tables
             # second table.
             insert = []
             update = []
+            
+            # we have our own idea of the primary keys 
+            # for this table, in the case that the user
+            # specified custom primary keys.
+            pk = {}
+            for k in self.primary_keys[table]:
+                pk[k] = k
             for obj in objects:
                 
 #                print "SAVE_OBJ we are " + hash_key(self) + " obj: " +  obj.__class__.__name__ + repr(id(obj))
                 params = {}
 
                 for col in table.columns:
-                    if col.primary_key:
+                    #if col.primary_key:
+                    if pk.has_key(col):
                         if hasattr(obj, "_instance_key"):
                             params[col.table.name + "_" + col.key] = self._getattrbycolumn(obj, col)
                         else:
             
     def _compile(self, whereclause = None, order_by = None, **options):
         statement = sql.select([self.table], whereclause, order_by = order_by)
-        statement.order_by(self.primarytable.rowid_column)
+        statement.order_by(self.table.rowid_column)
         # plugin point
         for key, value in self.props.iteritems():
             value.setup(key, statement, **options) 
         return instance
 
         
-class MapperProperty:
+class MapperProperty(object):
     """an element attached to a Mapper that describes and assists in the loading and saving 
     of an attribute on an object instance."""
     def execute(self, instance, row, identitykey, imap, isnew):
 
     """describes an object property that holds a single item or list of items that correspond
     to a related database table."""
-    def __init__(self, argument, secondary, primaryjoin, secondaryjoin, foreignkey=None, uselist=None, private=False, live=False, isoption=False, isassociation=False, **kwargs):
+    def __init__(self, argument, secondary, primaryjoin, secondaryjoin, foreignkey=None, uselist=None, private=False, live=False, isoption=False, association=None, **kwargs):
         self.uselist = uselist
         self.argument = argument
         self.secondary = secondary
         self.private = private
         self.live = live
         self.isoption = isoption
-        self.isassociation = isassociation
+        self.association = association
         self._hash_key = "%s(%s, %s, %s, %s, %s, %s, %s)" % (self.__class__.__name__, hash_key(self.argument), hash_key(secondary), hash_key(primaryjoin), hash_key(secondaryjoin), hash_key(foreignkey), repr(uselist), repr(private))
 
     def _copy(self):
             self.mapper = class_mapper(self.argument)
         else:
             self.mapper = self.argument
-            
+
+        if self.association is not None:
+            if isinstance(self.association, type):
+                self.association = class_mapper(self.association)
+                
         self.target = self.mapper.table
         self.key = key
         self.parent = parent
 
             
     def register_dependencies(self, uowcommit):
-        if self.direction == PropertyLoader.CENTER:
+        if self.association is not None:
+            uowcommit.register_dependency(self.parent, self.mapper)
+            uowcommit.register_dependency(self.association, self.parent)
+            uowcommit.register_processor(self.parent, self, self.parent, False)
+            uowcommit.register_processor(self.parent, self, self.parent, True)
+        elif self.direction == PropertyLoader.CENTER:
             # with many-to-many, set the parent as dependent on us, then the 
             # list of associations as dependent on the parent
             # if only a list changes, the parent mapper is the only mapper that
                     self._synchronize(obj, child, None, True)
                     uowcommit.register_object(child)
                 uowcommit.register_deleted_list(childlist)
-        elif self.isassociation:
+        elif self.association is not None:
+            # TODO: this is new code, for managing "association objects".
+            # its probably glitchy.
             for obj in deplist:
                 childlist = getlist(obj, passive=True)
                 if childlist is None: continue
                 uowcommit.register_saved_list(childlist)
-
-                # TODO: sort out the association objects so that we only insert/delete/update those
-                # that are actually correct.
+                
+                d = {}
                 for child in childlist:
                     self._synchronize(obj, child, None, False)
-                    uowcommit.unregister_object(child)                    
+                    key = self.mapper.instance_key(child)
+                    d[key] = child
+                    uowcommit.unregister_object(child)
 
+                for child in childlist.added_items():
+                    uowcommit.register_object(child)
+                    key = self.mapper.instance_key(child)
+                    d[key] = child
+                    
+                for child in childlist.unchanged_items():
+                    key = self.mapper.instance_key(child)
+                    o = d[key]
+                    o._instance_key= key
+                    
                 for child in childlist.deleted_items():
-                    uowcommit.unregister_object(child)                    
-                    
+                    key = self.mapper.instance_key(child)
+                    if d.has_key(key):
+                        o = d[key]
+                        o._instance_key = key
+                        uowcommit.unregister_object(child)
+                    else:
+                        uowcommit.register_object(child, isdelete=True)
         else:
             for obj in deplist:
-                #print "PROCESS:", repr(obj)
                 if self.direction == PropertyLoader.RIGHT:
                     uowcommit.register_object(obj)
                 childlist = getlist(obj, passive=True)
                 if childlist is None: continue
                 uowcommit.register_saved_list(childlist)
                 for child in childlist.added_items():
-                    #print "parent", repr(obj), "child", repr(child), "EOF"
                     self._synchronize(obj, child, None, False)
                     if self.direction == PropertyLoader.LEFT:
                         uowcommit.register_object(child)
 def create_lazy_clause(table, primaryjoin, secondaryjoin, foreignkey):
     binds = {}
     def visit_binary(binary):
-        circular = binary.left.table is binary.right.table
+        circular = isinstance(binary.left, schema.Column) and isinstance(binary.right, schema.Column) and binary.left.table is binary.right.table
         if isinstance(binary.left, schema.Column) and ((not circular and binary.left.table is table) or (circular and foreignkey is binary.right)):
             binary.left = binds.setdefault(binary.left,
                     sql.BindParamClause(table.name + "_" + binary.left.name, None, shortname = binary.left.name))