Commits

Mike Bayer committed 61f5213

- reworked all lazy/deferred/expired callables to be
serializable class instances, added pickling tests
- cleaned up "deferred" polymorphic system so that the
mapper handles it entirely
- columns which are missing from a Query's select statement
now get automatically deferred during load.

Comments (0)

Files changed (13)

      have each method called only once per operation, use the same 
      instance of the extension for both mappers.
      [ticket:490]
+
+   - columns which are missing from a Query's select statement
+     now get automatically deferred during load.
      
+   - improved support for pickling of mapped entities.  Per-instance
+     lazy/deferred/expired callables are now serializable so that
+     they serialize and deserialize with _state. 
+       
    - new synonym() behavior: an attribute will be placed on the mapped
      class, if one does not exist already, in all cases. if a property
      already exists on the class, the synonym will decorate the property

lib/sqlalchemy/orm/attributes.py

 class ScalarAttributeImpl(AttributeImpl):
     """represents a scalar value-holding InstrumentedAttribute."""
 
-    accepts_global_callable = True
+    accepts_scalar_loader = True
     
     def delete(self, state):
         if self.key not in state.committed_state:
             state.committed_state[self.key] = state.dict.get(self.key, NO_VALUE)
 
+        # TODO: catch key errors, convert to attributeerror?
         del state.dict[self.key]
         state.modified=True
 
     Adds events to delete/set operations.
     """
 
-    accepts_global_callable = False
+    accepts_scalar_loader = False
 
     def __init__(self, class_, key, callable_, trackparent=False, extension=None, copy_function=None, compare_function=None, **kwargs):
         super(ScalarObjectAttributeImpl, self).__init__(class_, key,
         
     def delete(self, state):
         old = self.get(state)
+        # TODO: catch key errors, convert to attributeerror?
         del state.dict[self.key]
         self.fire_remove_event(state, old, self)
 
     CollectionAdapter, a "view" onto that object that presents consistent
     bag semantics to the orm layer independent of the user data implementation.
     """
-    accepts_global_callable = False
+    accepts_scalar_loader = False
     
     def __init__(self, class_, key, callable_, typecallable=None, trackparent=False, extension=None, copy_function=None, compare_function=None, **kwargs):
         super(CollectionAttributeImpl, self).__init__(class_, 
 
         collection = self.get_collection(state)
         collection.clear_with_event()
+        # TODO: catch key errors, convert to attributeerror?
         del state.dict[self.key]
 
     def initialize(self, state):
         self.mappers = {}
         self.attrs = {}
         self.has_mutable_scalars = False
-        
+
 class InstanceState(object):
     """tracks state information at the instance level."""
 
         self.dict = obj.__dict__
         self.committed_state = {}
         self.modified = False
-        self.trigger = None
         self.callables = {}
         self.parents = {}
         self.pending = {}
             return None
             
     def __getstate__(self):
-        return {'committed_state':self.committed_state, 'pending':self.pending, 'parents':self.parents, 'modified':self.modified, 'instance':self.obj()}
+        return {'committed_state':self.committed_state, 'pending':self.pending, 'parents':self.parents, 'modified':self.modified, 'instance':self.obj(), 'expired_attributes':getattr(self, 'expired_attributes', None), 'callables':self.callables}
     
     def __setstate__(self, state):
         self.committed_state = state['committed_state']
         self.obj = weakref.ref(state['instance'])
         self.class_ = self.obj().__class__
         self.dict = self.obj().__dict__
-        self.callables = {}
-        self.trigger = None
-    
+        self.callables = state['callables']
+        self.runid = None
+        self.appenders = {}
+        if state['expired_attributes'] is not None:
+            self.expire_attributes(state['expired_attributes'])
+
     def initialize(self, key):
         getattr(self.class_, key).impl.initialize(self)
         
     def set_callable(self, key, callable_):
         self.dict.pop(key, None)
         self.callables[key] = callable_
-    
-    def __fire_trigger(self):
+
+    def __call__(self):
+        """__call__ allows the InstanceState to act as a deferred 
+        callable for loading expired attributes, which is also
+        serializable.
+        """
         instance = self.obj()
-        self.trigger(instance, [k for k in self.expired_attributes if k not in self.dict])
+        self.class_._class_state.deferred_scalar_loader(instance, [k for k in self.expired_attributes if k not in self.committed_state])
         for k in self.expired_attributes:
             self.callables.pop(k, None)
         self.expired_attributes.clear()
         return ATTR_WAS_SET
     
+    def unmodified(self):
+        """a set of keys which have no uncommitted changes"""
+
+        return util.Set([
+            attr.impl.key for attr in _managed_attributes(self.class_) if
+            attr.impl.key not in self.committed_state
+            and (not hasattr(attr.impl, 'commit_to_state') or not attr.impl.check_mutable_modified(self))
+        ])
+    unmodified = property(unmodified)
+    
     def expire_attributes(self, attribute_names):
         if not hasattr(self, 'expired_attributes'):
             self.expired_attributes = util.Set()
+            
         if attribute_names is None:
             for attr in _managed_attributes(self.class_):
                 self.dict.pop(attr.impl.key, None)
-                self.callables[attr.impl.key] = self.__fire_trigger
-                self.expired_attributes.add(attr.impl.key)
+
+                if attr.impl.accepts_scalar_loader:
+                    self.callables[attr.impl.key] = self
+                    self.expired_attributes.add(attr.impl.key)
+
             self.committed_state = {}
         else:
             for key in attribute_names:
                 self.dict.pop(key, None)
                 self.committed_state.pop(key, None)
 
-                if not getattr(self.class_, key).impl.accepts_global_callable:
-                    continue
-
-                self.callables[key] = self.__fire_trigger
-                self.expired_attributes.add(key)
+                if getattr(self.class_, key).impl.accepts_scalar_loader:
+                    self.callables[key] = self
+                    self.expired_attributes.add(key)
                 
     def reset(self, key):
         """remove the given attribute and any callables associated with it."""
     if not '_class_state' in class_.__dict__:
         class_._class_state = ClassState()
     
-def register_class(class_, extra_init=None, on_exception=None):
+def register_class(class_, extra_init=None, on_exception=None, deferred_scalar_loader=None):
     # do a sweep first, this also helps some attribute extensions
     # (like associationproxy) become aware of themselves at the 
     # class level
         getattr(class_, key, None)
 
     _init_class_state(class_)
+    class_._class_state.deferred_scalar_loader=deferred_scalar_loader
     
     oldinit = None
     doinit = False

lib/sqlalchemy/orm/interfaces.py

 """
 from sqlalchemy import util, logging, exceptions
 from sqlalchemy.sql import expression
+from itertools import chain
 class_mapper = None
 
 __all__ = ['EXT_CONTINUE', 'EXT_STOP', 'EXT_PASS', 'MapperExtension',
         return prev + (mapper.base_mapper, key)
     else:
         return (mapper.base_mapper, key)
-        
+
+def serialize_path(path):
+    if path is None:
+        return None
+
+    return [
+        (mapper.class_, mapper.entity_name, key)
+        for mapper, key in [(path[i], path[i+1]) for i in range(0, len(path)-1, 2)]
+    ]
+    
+def deserialize_path(path):
+    if path is None:
+        return None
+
+    global class_mapper
+    if class_mapper is None:
+        from sqlalchemy.orm import class_mapper
+
+    return tuple(
+        chain(*[(class_mapper(cls, entity), key) for cls, entity, key in path])
+    )
 
 class MapperOption(object):
     """Describe a modification to a Query."""

lib/sqlalchemy/orm/mapper.py

         def on_exception(class_, oldinit, instance, args, kwargs):
             util.warn_exception(self.extension.init_failed, self, class_, oldinit, instance, args, kwargs)
 
-        attributes.register_class(self.class_, extra_init=extra_init, on_exception=on_exception)
+        attributes.register_class(self.class_, extra_init=extra_init, on_exception=on_exception, deferred_scalar_loader=_load_scalar_attributes)
         
         self._class_state = self.class_._class_state
         _mapper_registry[self] = True
             instance._sa_session_id = context.session.hash_key
             session_identity_map[identitykey] = instance
         
-        if currentload or context.populate_existing or self.always_refresh or state.trigger:
+        if currentload or context.populate_existing or self.always_refresh:
             if isnew:
                 state.runid = context.runid
-                state.trigger = None
                 context.progress.add(state)
-
+                
             if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
                 self.populate_instance(context, instance, row, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew)
-        
+
+        elif getattr(state, 'expired_attributes', None):
+            if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, only_load_props=state.expired_attributes, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE:
+                self.populate_instance(context, instance, row, only_load_props=state.expired_attributes, instancekey=identitykey, isnew=isnew)
+            
         if result is not None and ('append_result' not in extension.methods or extension.append_result(self, context, row, instance, result, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE):
             result.append(instance)
             
         return instance
-                
-    def _deferred_inheritance_condition(self, base_mapper, needs_tables):
-        def visit_binary(binary):
-            leftcol = binary.left
-            rightcol = binary.right
-            if leftcol is None or rightcol is None:
-                return
-            if leftcol.table not in needs_tables:
-                binary.left = sql.bindparam(None, None, type_=binary.right.type)
-                param_names.append((leftcol, binary.left))
-            elif rightcol not in needs_tables:
-                binary.right = sql.bindparam(None, None, type_=binary.right.type)
-                param_names.append((rightcol, binary.right))
-
-        allconds = []
-        param_names = []
-
-        for mapper in self.iterate_to_root():
-            if mapper is base_mapper:
-                break
-            allconds.append(visitors.traverse(mapper.inherit_condition, clone=True, visit_binary=visit_binary))
-        
-        return sql.and_(*allconds), param_names
 
     def translate_row(self, tomapper, row):
         """Translate the column keys of a row into a new or proxied
             populators = new_populators
         else:
             populators = existing_populators
-                
+
+        if only_load_props:
+            populators = [p for p in populators if p[0] in only_load_props]
+            
         for (key, populator) in populators:
             selectcontext.exec_with_path(self, key, populator, instance, row, ispostselect=ispostselect, isnew=isnew, **flags)
             
             p(state.obj())
 
     def _get_poly_select_loader(self, selectcontext, row):
-        # 'select' or 'union'+col not present
+        """set up attribute loaders for 'select' and 'deferred' polymorphic loading.
+        
+        this loading uses a second SELECT statement to load additional tables,
+        either immediately after loading the main table or via a deferred attribute trigger.
+        """
+        
         (hosted_mapper, needs_tables) = selectcontext.attributes.get(('polymorphic_fetch', self), (None, None))
-        if hosted_mapper is None or not needs_tables or hosted_mapper.polymorphic_fetch == 'deferred':
+        
+        if hosted_mapper is None or not needs_tables:
             return
         
         cond, param_names = self._deferred_inheritance_condition(hosted_mapper, needs_tables)
         statement = sql.select(needs_tables, cond, use_labels=True)
-        def post_execute(instance, **flags):
-            if self.__should_log_debug:
-                self.__log_debug("Post query loading instance " + instance_str(instance))
+        
+        if hosted_mapper.polymorphic_fetch == 'select':
+            def post_execute(instance, **flags):
+                if self.__should_log_debug:
+                    self.__log_debug("Post query loading instance " + instance_str(instance))
 
-            identitykey = self.identity_key_from_instance(instance)
+                identitykey = self.identity_key_from_instance(instance)
 
-            params = {}
-            for c, bind in param_names:
-                params[bind] = self._get_attr_by_column(instance, c)
-            row = selectcontext.session.connection(self).execute(statement, params).fetchone()
-            self.populate_instance(selectcontext, instance, row, isnew=False, instancekey=identitykey, ispostselect=True)
+                params = {}
+                for c, bind in param_names:
+                    params[bind] = self._get_attr_by_column(instance, c)
+                row = selectcontext.session.connection(self).execute(statement, params).fetchone()
+                self.populate_instance(selectcontext, instance, row, isnew=False, instancekey=identitykey, ispostselect=True)
+            return post_execute
+        elif hosted_mapper.polymorphic_fetch == 'deferred':
+            from sqlalchemy.orm.strategies import DeferredColumnLoader
+            
+            def post_execute(instance, **flags):
+                def create_statement(instance):
+                    params = {}
+                    for (c, bind) in param_names:
+                        # use the "committed" (database) version to get query column values
+                        params[bind] = self._get_committed_attr_by_column(instance, c)
+                    return (statement, params)
+                
+                props = [prop for prop in [self._get_col_to_prop(col) for col in statement.inner_columns] if prop.key not in instance.__dict__]
+                keys = [p.key for p in props]
+                for prop in props:
+                    strategy = prop._get_strategy(DeferredColumnLoader)
+                    instance._state.set_callable(prop.key, strategy.setup_loader(instance, props=keys, create_statement=create_statement))
+            return post_execute
+        else:
+            return None
 
-        return post_execute
+    def _deferred_inheritance_condition(self, base_mapper, needs_tables):
+        def visit_binary(binary):
+            leftcol = binary.left
+            rightcol = binary.right
+            if leftcol is None or rightcol is None:
+                return
+            if leftcol.table not in needs_tables:
+                binary.left = sql.bindparam(None, None, type_=binary.right.type)
+                param_names.append((leftcol, binary.left))
+            elif rightcol not in needs_tables:
+                binary.right = sql.bindparam(None, None, type_=binary.right.type)
+                param_names.append((rightcol, binary.right))
+
+        allconds = []
+        param_names = []
+
+        for mapper in self.iterate_to_root():
+            if mapper is base_mapper:
+                break
+            allconds.append(visitors.traverse(mapper.inherit_condition, clone=True, visit_binary=visit_binary))
+
+        return sql.and_(*allconds), param_names
             
 Mapper.logger = logging.class_logger(Mapper)
 
 
     return hasattr(object, '_entity_name')
 
+object_session = None
+
+def _load_scalar_attributes(instance, attribute_names):
+    global object_session
+    if not object_session:
+        from sqlalchemy.orm.session import object_session
+        
+    if object_session(instance).query(object_mapper(instance))._get(instance._instance_key, refresh_instance=instance._state, only_load_props=attribute_names) is None:
+        raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % instance_str(instance))
+
 def _state_mapper(state, entity_name=None):
     return state.class_._class_state.mappers[state.dict.get('_entity_name', entity_name)]
 

lib/sqlalchemy/orm/session.py

         
         return util.IdentitySet(self.uow.new.values())
     new = property(new)
-    
+
 def _expire_state(state, attribute_names):
     """Standalone expire instance function.
 
     If the list is None or blank, the entire instance is expired.
     """
 
-    if state.trigger is None:
-        def load_attributes(instance, attribute_names):
-            if object_session(instance).query(instance.__class__)._get(instance._instance_key, refresh_instance=instance._state, only_load_props=attribute_names) is None:
-                raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(instance))
-        state.trigger = load_attributes
-
     state.expire_attributes(attribute_names)
 
 register_attribute = unitofwork.register_attribute

lib/sqlalchemy/orm/strategies.py

 from sqlalchemy.sql import util as sql_util
 from sqlalchemy.sql import visitors, expression, operators
 from sqlalchemy.orm import mapper, attributes
-from sqlalchemy.orm.interfaces import LoaderStrategy, StrategizedOption, MapperOption, PropertyOption
+from sqlalchemy.orm.interfaces import LoaderStrategy, StrategizedOption, MapperOption, PropertyOption, serialize_path, deserialize_path
 from sqlalchemy.orm import session as sessionlib
 from sqlalchemy.orm import util as mapperutil
 
             if self._should_log_debug:
                 self.logger.debug("Returning active column fetcher for %s %s" % (mapper, self.key))
             return (new_execute, None, None)
-
-        # our mapped column is not present in the row.  check if we need to initialize a polymorphic
-        # row fetcher used by inheritance.
-        (hosted_mapper, needs_tables) = selectcontext.attributes.get(('polymorphic_fetch', mapper), (None, None))
-
-        if hosted_mapper is None:
-            return (None, None, None)
-        
-        if hosted_mapper.polymorphic_fetch == 'deferred':
-            # 'deferred' polymorphic row fetcher, put a callable on the property.
-            # create a deferred column loader which will query the remaining not-yet-loaded tables in an inheritance load.
-            # the mapper for the object creates the WHERE criterion using the mapper who originally 
-            # "hosted" the query and the list of tables which are unloaded between the "hosted" mapper
-            # and this mapper.  (i.e. A->B->C, the query used mapper A.  therefore will need B's and C's tables
-            # in the query).
-            
-            # deferred loader strategy
-            strategy = self.parent_property._get_strategy(DeferredColumnLoader)
-            
-            # full list of ColumnProperty objects to be loaded in the deferred fetch
-            props = [p.key for p in mapper.iterate_properties if isinstance(p.strategy, ColumnLoader) and p.columns[0].table in needs_tables]
-
-            # TODO: we are somewhat duplicating efforts from mapper._get_poly_select_loader 
-            # and should look for ways to simplify.
-            cond, param_names = mapper._deferred_inheritance_condition(hosted_mapper, needs_tables)
-            statement = sql.select(needs_tables, cond, use_labels=True)
-            def create_statement(instance):
-                params = {}
-                for (c, bind) in param_names:
-                    # use the "committed" (database) version to get query column values
-                    params[bind] = mapper._get_committed_attr_by_column(instance, c)
-                return (statement, params)
-            
+        else:
             def new_execute(instance, row, isnew, **flags):
                 if isnew:
-                    instance._state.set_callable(self.key, strategy.setup_loader(instance, props=props, create_statement=create_statement))
-                    
+                    instance._state.expire_attributes([self.key])
             if self._should_log_debug:
-                self.logger.debug("Returning deferred column fetcher for %s %s" % (mapper, self.key))
-                
+                self.logger.debug("Deferring load for %s %s" % (mapper, self.key))
             return (new_execute, None, None)
-        else:  
-            # immediate polymorphic row fetcher.  no processing needed for this row.
-            if self._should_log_debug:
-                self.logger.debug("Returning no column fetcher for %s %s" % (mapper, self.key))
-            return (None, None, None)
-
 
 ColumnLoader.logger = logging.class_logger(ColumnLoader)
 
             self.parent_property._get_strategy(ColumnLoader).setup_query(context, **kwargs)
         
     def setup_loader(self, instance, props=None, create_statement=None):
-        localparent = mapper.object_mapper(instance, raiseerror=False)
-        if localparent is None:
+        if not mapper.has_mapper(instance):
             return None
+            
+        localparent = mapper.object_mapper(instance)
 
         # adjust for the ColumnProperty associated with the instance
         # not being our own ColumnProperty.  This can occur when entity_name
         prop = localparent.get_property(self.key)
         if prop is not self.parent_property:
             return prop._get_strategy(DeferredColumnLoader).setup_loader(instance)
-            
-        def lazyload():
-            if not mapper.has_identity(instance):
-                return None
-            
-            if props is not None:
-                group = props
-            elif self.group is not None:
-                group = [p.key for p in localparent.iterate_properties if isinstance(p.strategy, DeferredColumnLoader) and p.group==self.group]
-            else:
-                group = [self.parent_property.key]
-            
-            # narrow the keys down to just those which aren't present on the instance
-            group = [k for k in group if k not in instance.__dict__]
-            
-            if self._should_log_debug:
-                self.logger.debug("deferred load %s group %s" % (mapperutil.attribute_str(instance, self.key), group and ','.join(group) or 'None'))
 
-            session = sessionlib.object_session(instance)
-            if session is None:
-                raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session; deferred load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key))
-
-            if create_statement is None:
-                ident = instance._instance_key[1]
-                session.query(localparent)._get(None, ident=ident, only_load_props=group, refresh_instance=instance._state)
-            else:
-                statement, params = create_statement(instance)
-                session.query(localparent).from_statement(statement).params(params)._get(None, only_load_props=group, refresh_instance=instance._state)
-            return attributes.ATTR_WAS_SET
-        return lazyload
+        return LoadDeferredColumns(instance, self.key, props, optimizing_statement=create_statement)
                 
 DeferredColumnLoader.logger = logging.class_logger(DeferredColumnLoader)
 
+class LoadDeferredColumns(object):
+    """callable, serializable loader object used by DeferredColumnLoader"""
+    
+    def __init__(self, instance, key, keys, optimizing_statement):
+        self.instance = instance
+        self.key = key
+        self.keys = keys
+        self.optimizing_statement = optimizing_statement
+
+    def __getstate__(self):
+        return {'instance':self.instance, 'key':self.key, 'keys':self.keys}
+    
+    def __setstate__(self, state):
+        self.instance = state['instance']
+        self.key = state['key']
+        self.keys = state['keys']
+        self.optimizing_statement = None
+        
+    def __call__(self):
+        if not mapper.has_identity(self.instance):
+            return None
+            
+        localparent = mapper.object_mapper(self.instance, raiseerror=False)
+        
+        prop = localparent.get_property(self.key)
+        strategy = prop._get_strategy(DeferredColumnLoader)
+
+        if self.keys:
+            toload = self.keys
+        elif strategy.group:
+            toload = [p.key for p in localparent.iterate_properties if isinstance(p.strategy, DeferredColumnLoader) and p.group==strategy.group]
+        else:
+            toload = [self.key]
+
+        # narrow the keys down to just those which have no history
+        group = [k for k in toload if k in self.instance._state.unmodified]
+
+        if strategy._should_log_debug:
+            strategy.logger.debug("deferred load %s group %s" % (mapperutil.attribute_str(self.instance, self.key), group and ','.join(group) or 'None'))
+
+        session = sessionlib.object_session(self.instance)
+        if session is None:
+            raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session; deferred load operation of attribute '%s' cannot proceed" % (self.instance.__class__, self.key))
+
+        query = session.query(localparent)
+        if not self.optimizing_statement:
+            ident = self.instance._instance_key[1]
+            query._get(None, ident=ident, only_load_props=group, refresh_instance=self.instance._state)
+        else:
+            statement, params = self.optimizing_statement(self.instance)
+            query.from_statement(statement).params(params)._get(None, only_load_props=group, refresh_instance=self.instance._state)
+        return attributes.ATTR_WAS_SET
+
 class DeferredOption(StrategizedOption):
     def __init__(self, key, defer=False):
         super(DeferredOption, self).__init__(key)
 class LazyLoader(AbstractRelationLoader):
     def init(self):
         super(LazyLoader, self).init()
-        (self.lazywhere, self.lazybinds, self.lazyreverse) = self._create_lazy_clause(self)
+        (self.lazywhere, self.lazybinds, self.equated_columns) = self._create_lazy_clause(self)
         
         self.logger.info(str(self.parent_property) + " lazy loading clause " + str(self.lazywhere))
 
 
     def lazy_clause(self, instance, reverse_direction=False):
         if instance is None:
-            return self.lazy_none_clause(reverse_direction)
+            return self._lazy_none_clause(reverse_direction)
             
         if not reverse_direction:
-            (criterion, lazybinds, rev) = (self.lazywhere, self.lazybinds, self.lazyreverse)
+            (criterion, lazybinds, rev) = (self.lazywhere, self.lazybinds, self.equated_columns)
         else:
             (criterion, lazybinds, rev) = LazyLoader._create_lazy_clause(self.parent_property, reverse_direction=reverse_direction)
         bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds])
                 bindparam.value = mapper._get_committed_attr_by_column(instance, bind_to_col[bindparam.key])
         return visitors.traverse(criterion, clone=True, visit_bindparam=visit_bindparam)
     
-    def lazy_none_clause(self, reverse_direction=False):
+    def _lazy_none_clause(self, reverse_direction=False):
         if not reverse_direction:
-            (criterion, lazybinds, rev) = (self.lazywhere, self.lazybinds, self.lazyreverse)
+            (criterion, lazybinds, rev) = (self.lazywhere, self.lazybinds, self.equated_columns)
         else:
             (criterion, lazybinds, rev) = LazyLoader._create_lazy_clause(self.parent_property, reverse_direction=reverse_direction)
         bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds])
     def setup_loader(self, instance, options=None, path=None):
         if not mapper.has_mapper(instance):
             return None
-        else:
-            # adjust for the PropertyLoader associated with the instance
-            # not being our own PropertyLoader.  This can occur when entity_name
-            # mappers are used to map different versions of the same PropertyLoader
-            # to the class.
-            prop = mapper.object_mapper(instance).get_property(self.key)
-            if prop is not self.parent_property:
-                return prop._get_strategy(LazyLoader).setup_loader(instance)
 
-        def lazyload():
-            if self._should_log_debug:
-                self.logger.debug("lazy load attribute %s on instance %s" % (self.key, mapperutil.instance_str(instance)))
+        localparent = mapper.object_mapper(instance)
 
-            if not mapper.has_identity(instance):
-                return None
-
-            session = sessionlib.object_session(instance)
-            if session is None:
-                try:
-                    session = mapper.object_mapper(instance).get_session()
-                except exceptions.InvalidRequestError:
-                    raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session, and no contextual session is established; lazy load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key))
-
-            # if we have a simple straight-primary key load, use mapper.get()
-            # to possibly save a DB round trip
-            q = session.query(self.mapper).autoflush(False)
-            if path:
-                q = q._with_current_path(path)
-            if self.use_get:
-                params = {}
-                for col, bind in self.lazybinds.iteritems():
-                    # use the "committed" (database) version to get query column values
-                    params[bind.key] = self.parent._get_committed_attr_by_column(instance, col)
-                ident = []
-                nonnulls = False
-                for primary_key in self.select_mapper.primary_key: 
-                    bind = self.lazyreverse[primary_key]
-                    v = params[bind.key]
-                    if v is not None:
-                        nonnulls = True
-                    ident.append(v)
-                if not nonnulls:
-                    return None
-                if options:
-                    q = q._conditional_options(*options)
-                return q.get(ident)
-            elif self.order_by is not False:
-                q = q.order_by(self.order_by)
-            elif self.secondary is not None and self.secondary.default_order_by() is not None:
-                q = q.order_by(self.secondary.default_order_by())
-
-            if options:
-                q = q._conditional_options(*options)
-            q = q.filter(self.lazy_clause(instance))
-
-            result = q.all()
-            if self.uselist:
-                return result
-            else:
-                if result:
-                    return result[0]
-                else:
-                    return None
-
-        return lazyload
+        # adjust for the PropertyLoader associated with the instance
+        # not being our own PropertyLoader.  This can occur when entity_name
+        # mappers are used to map different versions of the same PropertyLoader
+        # to the class.
+        prop = localparent.get_property(self.key)
+        if prop is not self.parent_property:
+            return prop._get_strategy(LazyLoader).setup_loader(instance)
+        
+        return LoadLazyAttribute(instance, self.key, options, path)
 
     def create_row_processor(self, selectcontext, mapper, row):
         if not self.is_class_level or len(selectcontext.options):
         (primaryjoin, secondaryjoin, remote_side) = (prop.polymorphic_primaryjoin, prop.polymorphic_secondaryjoin, prop.remote_side)
         
         binds = {}
-        reverse = {}
+        equated_columns = {}
 
         def should_bind(targetcol, othercol):
             if reverse_direction and not secondaryjoin:
                 return
             leftcol = binary.left
             rightcol = binary.right
-            
+
+            equated_columns[rightcol] = leftcol
+            equated_columns[leftcol] = rightcol
+
             if should_bind(leftcol, rightcol):
-                col = leftcol
-                binary.left = binds.setdefault(leftcol,
-                        sql.bindparam(None, None, type_=binary.right.type))
-                reverse[rightcol] = binds[col]
+                binary.left = binds[leftcol] = sql.bindparam(None, None, type_=binary.right.type)
 
             # the "left is not right" compare is to handle part of a join clause that is "table.c.col1==table.c.col1",
             # which can happen in rare cases (test/orm/relationships.py RelationTest2)
             if leftcol is not rightcol and should_bind(rightcol, leftcol):
-                col = rightcol
-                binary.right = binds.setdefault(rightcol,
-                        sql.bindparam(None, None, type_=binary.left.type))
-                reverse[leftcol] = binds[col]
+                binary.right = binds[rightcol] = sql.bindparam(None, None, type_=binary.left.type)
 
         lazywhere = primaryjoin
         
             if reverse_direction:
                 secondaryjoin = visitors.traverse(secondaryjoin, clone=True, visit_binary=visit_binary)
             lazywhere = sql.and_(lazywhere, secondaryjoin)
-        return (lazywhere, binds, reverse)
+        return (lazywhere, binds, equated_columns)
     _create_lazy_clause = classmethod(_create_lazy_clause)
     
 LazyLoader.logger = logging.class_logger(LazyLoader)
 
+class LoadLazyAttribute(object):
+    """callable, serializable loader object used by LazyLoader"""
+
+    # Stored per-instance as the attribute's lazy callable.  Holds only
+    # picklable references (instance, attribute key, query options, mapper
+    # path) so the callable serializes/deserializes along with _state.
+    def __init__(self, instance, key, options, path):
+        self.instance = instance
+        self.key = key
+        self.options = options
+        self.path = path
+        
+    def __getstate__(self):
+        # the path contains mapper objects, which are not directly
+        # picklable; serialize_path converts it to a serializable form
+        return {'instance':self.instance, 'key':self.key, 'options':self.options, 'path':serialize_path(self.path)}
+
+    def __setstate__(self, state):
+        self.instance = state['instance']
+        self.key = state['key']
+        self.options= state['options']
+        self.path = deserialize_path(state['path'])
+        
+    def __call__(self):
+        instance = self.instance
+        
+        # only instances which have an identity (i.e. were loaded from or
+        # flushed to the database) can be lazy-loaded
+        if not mapper.has_identity(instance):
+            return None
+
+        # re-resolve the property and strategy from the live mapper
+        # registry, since this callable may have been unpickled in a
+        # different process than the one that created it
+        instance_mapper = mapper.object_mapper(instance)
+        prop = instance_mapper.get_property(self.key)
+        strategy = prop._get_strategy(LazyLoader)
+        
+        if strategy._should_log_debug:
+            strategy.logger.debug("lazy load attribute %s on instance %s" % (self.key, mapperutil.instance_str(instance)))
+
+        session = sessionlib.object_session(instance)
+        if session is None:
+            try:
+                session = instance_mapper.get_session()
+            except exceptions.InvalidRequestError:
+                raise exceptions.InvalidRequestError("Parent instance %s is not bound to a Session, and no contextual session is established; lazy load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key))
+
+        q = session.query(prop.mapper).autoflush(False)
+        if self.path:
+            q = q._with_current_path(self.path)
+            
+        # if we have a simple primary key load, use mapper.get()
+        # to possibly save a DB round trip
+        if strategy.use_get:
+            ident = []
+            allnulls = True
+            for primary_key in prop.select_mapper.primary_key: 
+                val = instance_mapper._get_committed_attr_by_column(instance, strategy.equated_columns[primary_key])
+                allnulls = allnulls and val is None
+                ident.append(val)
+            # all-NULL foreign key values mean no related row exists
+            if allnulls:
+                return None
+            if self.options:
+                q = q._conditional_options(*self.options)
+            return q.get(ident)
+            
+        # otherwise, emit a full query against the lazy clause
+        if strategy.order_by is not False:
+            q = q.order_by(strategy.order_by)
+        elif strategy.secondary is not None and strategy.secondary.default_order_by() is not None:
+            q = q.order_by(strategy.secondary.default_order_by())
+
+        if self.options:
+            q = q._conditional_options(*self.options)
+        q = q.filter(strategy.lazy_clause(instance))
+
+        result = q.all()
+        if strategy.uselist:
+            return result
+        else:
+            if result:
+                return result[0]
+            else:
+                return None
+        
 
 class EagerLoader(AbstractRelationLoader):
     """Loads related objects inline with a parent query."""
             if self._should_log_debug:
                 self.logger.debug("eager loader %s degrading to lazy loader" % str(self))
             return self.parent_property._get_strategy(LazyLoader).create_row_processor(selectcontext, mapper, row)
-        
-            
+
     def __str__(self):
         return str(self.parent) + "." + self.key
         

lib/sqlalchemy/orm/util.py

 
 def state_str(state):
     """Return a string describing an instance."""
-
-    return state.class_.__name__ + "@" + hex(id(state.obj()))
+    # state may legitimately be None (e.g. when logging operations that
+    # have no associated state); render it rather than raising
+    if state is None:
+        return "None"
+    else:
+        return state.class_.__name__ + "@" + hex(id(state.obj()))
 
 def attribute_str(instance, attribute):
     return instance_str(instance) + "." + attribute

lib/sqlalchemy/sql/expression.py

     'subquery', 'table', 'text', 'union', 'union_all', 'update', ]
 
 
-BIND_PARAMS = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE)
 
 def desc(column):
     """Return a descending ``ORDER BY`` clause element.
 
     __visit_name__ = 'textclause'
 
+    _bind_params_regex = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE)
+
     def __init__(self, text = "", bind=None, bindparams=None, typemap=None):
         self._bind = bind
         self.bindparams = {}
 
         # scan the string and search for bind parameter names, add them
         # to the list of bindparams
-        self.text = BIND_PARAMS.sub(repl, text)
+        self.text = self._bind_params_regex.sub(repl, text)
         if bindparams is not None:
             for b in bindparams:
                 self.bindparams[b.key] = b

test/orm/alltests.py

         'orm.relationships',
         'orm.association',
         'orm.merge',
+        'orm.pickled',
         'orm.memusage',
         
         'orm.cycles',

test/orm/attributes.py

         self.assert_(o4.mt2[0].a == 'abcde')
         self.assert_(o4.mt2[0].b is None)
 
+    def test_deferred(self):
+        """test that expired scalar attributes invoke the class-level
+        deferred_scalar_loader to repopulate themselves."""
+        class Foo(object):pass
+        
+        data = {'a':'this is a', 'b':12}
+        # stands in for a database refresh; returning ATTR_WAS_SET signals
+        # that the attribute was populated directly in __dict__
+        def loader(instance, keys):
+            for k in keys:
+                instance.__dict__[k] = data[k]
+            return attributes.ATTR_WAS_SET
+            
+        attributes.register_class(Foo, deferred_scalar_loader=loader)
+        attributes.register_attribute(Foo, 'a', uselist=False, useobject=False)
+        attributes.register_attribute(Foo, 'b', uselist=False, useobject=False)
+        
+        # expiring with keys=None expires all attributes; access triggers load
+        f = Foo()
+        f._state.expire_attributes(None)
+        self.assertEquals(f.a, "this is a")
+        self.assertEquals(f.b, 12)
+        
+        # a pending (uncommitted) change is discarded by expiration
+        f.a = "this is some new a"
+        f._state.expire_attributes(None)
+        self.assertEquals(f.a, "this is a")
+        self.assertEquals(f.b, 12)
+
+        # a change made after expiration takes precedence over the loader
+        f._state.expire_attributes(None)
+        f.a = "this is another new a"
+        self.assertEquals(f.a, "this is another new a")
+        self.assertEquals(f.b, 12)
+
+        f._state.expire_attributes(None)
+        self.assertEquals(f.a, "this is a")
+        self.assertEquals(f.b, 12)
+
+        # after deletion the attribute resolves to None rather than re-loading
+        del f.a
+        self.assertEquals(f.a, None)
+        self.assertEquals(f.b, 12)
+        
+        f._state.commit_all()
+        self.assertEquals(f.a, None)
+        self.assertEquals(f.b, 12)
+
+    def test_deferred_pickleable(self):
+        """test that an instance with expired attributes survives a
+        pickle round trip and still triggers its deferred loader."""
+        data = {'a':'this is a', 'b':12}
+        def loader(instance, keys):
+            for k in keys:
+                instance.__dict__[k] = data[k]
+            return attributes.ATTR_WAS_SET
+            
+        # NOTE: uses the module-level MyTest class (not a local class as
+        # above) so that pickle can locate the class on deserialization
+        attributes.register_class(MyTest, deferred_scalar_loader=loader)
+        attributes.register_attribute(MyTest, 'a', uselist=False, useobject=False)
+        attributes.register_attribute(MyTest, 'b', uselist=False, useobject=False)
+        
+        m = MyTest()
+        m._state.expire_attributes(None)
+        assert 'a' not in m.__dict__
+        m2 = pickle.loads(pickle.dumps(m))
+        # the expired state pickles as expired; loader fires on access
+        assert 'a' not in m2.__dict__
+        self.assertEquals(m2.a, "this is a")
+        self.assertEquals(m2.b, 12)
+        
     def test_list(self):
         class User(object):pass
         class Address(object):pass
         self.assertEquals(attributes.get_history(f._state, 'bars'), ([bar4], [], []))
 
         lazy_load = [bar1, bar2, bar3]
-        f._state.trigger = lazyload(f)
         f._state.expire_attributes(['bars'])
         self.assertEquals(attributes.get_history(f._state, 'bars'), ([], [bar1, bar2, bar3], []))
         

test/orm/expire.py

         self.assert_sql_count(testbase.db, go, 1)
         assert 'name' in u.__dict__
 
-        # we're changing the database here, so if this test fails in the middle,
-        # it'll screw up the other tests which are hardcoded to 7/'jack'
         u.name = 'foo'
         sess.flush()
         # change the value in the DB
         # test that it refreshed
         assert u.__dict__['name'] == 'jack'
 
-        # object should be back to normal now,
-        # this should *not* produce a SELECT statement (not tested here though....)
-        assert u.name == 'jack'
+        def go():
+            assert u.name == 'jack'
+        self.assert_sql_count(testbase.db, go, 0)
     
     def test_expire_doesntload_on_set(self):
         mapper(User, users)
             assert o.isopen == 1
         self.assert_sql_count(testbase.db, go, 1)
         assert o.description == 'order 3 modified'
+
+        del o.description
+        assert "description" not in o.__dict__
+        sess.expire(o, ['isopen'])
+        sess.query(Order).all()
+        assert o.isopen == 1
+        assert "description" not in o.__dict__
+
+        assert o.description is None
         
     def test_expire_committed(self):
         """test that the committed state of the attribute receives the most recent DB data"""
         def go():
             assert u.addresses[0].email_address == 'jack@bean.com'
             assert u.name == 'jack'
-        # one load
-        self.assert_sql_count(testbase.db, go, 1)
+        # two loads, since relation() + scalar are 
+        # separate right now
+        self.assert_sql_count(testbase.db, go, 2)
         assert 'name' in u.__dict__
         assert 'addresses' in u.__dict__
 
+        sess.expire(u, ['name', 'addresses'])
+        assert 'name' not in u.__dict__
+        assert 'addresses' not in u.__dict__
+
     def test_partial_expire(self):
         mapper(Order, orders)
 
         s.expire(u)
 
         # get the attribute, it refreshes
+        print "OK------"
+#        print u.__dict__
+#        print u._state.callables
         assert u.name == 'jack'
         assert id(a) not in [id(x) for x in u.addresses]
 

test/orm/mapper.py

         a = s.query(Address).from_statement(select([addresses.c.address_id, addresses.c.user_id])).first()
         assert a.user_id == 7
         assert a.address_id == 1
-        assert a.email_address is None
+        # email address auto-defers
+        assert 'email_address' not in a.__dict__
+        assert a.email_address == 'jack@bean.com'
 
     def test_badconstructor(self):
         """test that if the construction of a mapped class fails, the instnace does not get placed in the session"""

test/orm/pickled.py

+import testbase
+from sqlalchemy import *
+from sqlalchemy import exceptions
+from sqlalchemy.orm import *
+from testlib import *
+from testlib.fixtures import *
+import pickle
+
+class EmailUser(User):
+    # subclass fixture for PolymorphicDeferredTest; defined at module
+    # level so that pickle can locate the class on deserialization
+    pass
+    
+class PickleTest(FixtureTest):
+    keep_mappers = False
+    keep_data = False
+    
+    def test_transient(self):
+        """pickle/unpickle a transient (never-flushed) instance; the
+        copy can be saved and flushed as a new entity."""
+        mapper(User, users, properties={
+            'addresses':relation(Address, backref="user")
+        })
+        mapper(Address, addresses)
+        
+        sess = create_session()
+        u1 = User(name='ed')
+        u1.addresses.append(Address(email_address='ed@bar.com'))
+        
+        u2 = pickle.loads(pickle.dumps(u1))
+        sess.save(u2)
+        sess.flush()
+        
+        sess.clear()
+        
+        # fixture __eq__ compares by value, so the original transient
+        # u1 should equal the freshly-loaded persisted copy
+        self.assertEquals(u1, sess.query(User).get(u2.id))
+    
+    def test_class_deferred_cols(self):
+        """pickle/unpickle an instance with mapper-level deferred()
+        columns; deferred loaders still fire on the unpickled copy."""
+        mapper(User, users, properties={
+            'name':deferred(users.c.name),
+            'addresses':relation(Address, backref="user")
+        })
+        mapper(Address, addresses, properties={
+            'email_address':deferred(addresses.c.email_address)
+        })
+        sess = create_session()
+        u1 = User(name='ed')
+        u1.addresses.append(Address(email_address='ed@bar.com'))
+        sess.save(u1)
+        sess.flush()
+        sess.clear()
+        u1 = sess.query(User).get(u1.id)
+        # both the deferred column and the lazy relation start unloaded
+        assert 'name' not in u1.__dict__
+        assert 'addresses' not in u1.__dict__
+        
+        # the unpickled copy, attached to a new session, loads on access
+        u2 = pickle.loads(pickle.dumps(u1))
+        sess2 = create_session()
+        sess2.update(u2)
+        self.assertEquals(u2.name, 'ed')
+        self.assertEquals(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')]))
+        
+    def test_instance_deferred_cols(self):
+        mapper(User, users, properties={
+            'addresses':relation(Address, backref="user")
+        })
+        mapper(Address, addresses)
+        
+        sess = create_session()
+        u1 = User(name='ed')
+        u1.addresses.append(Address(email_address='ed@bar.com'))
+        sess.save(u1)
+        sess.flush()
+        sess.clear()
+        
+        u1 = sess.query(User).options(defer('name'), defer('addresses.email_address')).get(u1.id)
+        assert 'name' not in u1.__dict__
+        assert 'addresses' not in u1.__dict__
+        
+        u2 = pickle.loads(pickle.dumps(u1))
+        sess2 = create_session()
+        sess2.update(u2)
+        self.assertEquals(u2.name, 'ed')
+        assert 'addresses' not in u1.__dict__
+        ad = u2.addresses[0]
+        assert 'email_address' not in ad.__dict__
+        self.assertEquals(ad.email_address, 'ed@bar.com')
+        self.assertEquals(u2, User(name='ed', addresses=[Address(email_address='ed@bar.com')]))
+
+class PolymorphicDeferredTest(ORMTest):
+    """test pickling of instances using deferred polymorphic fetch,
+    where subclass-table columns load via a per-instance callable."""
+
+    def define_tables(self, metadata):
+        global users, email_users
+        users = Table('users', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('name', String(30)),
+            Column('type', String(30)),
+            )
+        email_users = Table('email_users', metadata,
+            Column('id', Integer, ForeignKey('users.id'), primary_key=True),
+            Column('email_address', String(30))
+            )
+            
+    def test_polymorphic_deferred(self):
+        # polymorphic_fetch='deferred' defers subclass-table columns
+        # until first access rather than joining at load time
+        mapper(User, users, polymorphic_identity='user', polymorphic_on=users.c.type, polymorphic_fetch='deferred')
+        mapper(EmailUser, email_users, inherits=User, polymorphic_identity='emailuser')
+        
+        eu = EmailUser(name="user1", email_address='foo@bar.com')
+        sess = create_session()
+        sess.save(eu)
+        sess.flush()
+        sess.clear()
+        
+        # the deferred column callable must survive the pickle round trip
+        eu = sess.query(User).first()
+        eu2 = pickle.loads(pickle.dumps(eu))
+        sess2 = create_session()
+        sess2.update(eu2)
+        assert 'email_address' not in eu2.__dict__
+        self.assertEquals(eu2.email_address, 'foo@bar.com')
+        
+        
+        
+
+if __name__ == '__main__':
+    testbase.main()