Commits

Mike Bayer  committed 96f5ae2

- session.merge() will not expire attributes on the returned
instance if that instance is "pending". [ticket:1789]

  • Participants
  • Parent commits 0f1b087

Comments (0)

Files changed (9)

 - orm
   - Fixed regression introduced in 0.6.0 involving improper
     history accounting on mutable attributes.  [ticket:1782]
-    
+  
+  - session.merge() will not expire attributes on the returned
+    instance if that instance is "pending".  [ticket:1789]
+
 - sql
   - Fixed bug that prevented implicit RETURNING from functioning
     properly with composite primary key that contained zeroes.

File lib/sqlalchemy/orm/dynamic.py

     attributes, object_session, util as mapperutil, strategies, object_mapper
     )
 from sqlalchemy.orm.query import Query
-from sqlalchemy.orm.util import _state_has_identity, has_identity
+from sqlalchemy.orm.util import has_identity
 from sqlalchemy.orm import attributes, collections
 
 class DynaLoader(strategies.AbstractRelationshipLoader):
         collection_history = self._modified_event(state, dict_)
         new_values = list(iterable)
 
-        if _state_has_identity(state):
+        if state.has_identity:
             old_collection = list(self.get(state, dict_))
         else:
             old_collection = []

File lib/sqlalchemy/orm/mapper.py

     MapperProperty, EXT_CONTINUE, PropComparator
     )
 from sqlalchemy.orm.util import (
-     ExtensionCarrier, _INSTRUMENTOR, _class_to_mapper, _state_has_identity,
+     ExtensionCarrier, _INSTRUMENTOR, _class_to_mapper, 
      _state_mapper, class_mapper, instance_str, state_str,
      )
 
                 # column is coming in after _readonly_props was initialized; check
                 # for 'readonly'
                 if hasattr(self, '_readonly_props') and \
-                    (not hasattr(col, 'table') or col.table not in self._cols_by_table):
+                    (not hasattr(col, 'table') or 
+                    col.table not in self._cols_by_table):
                         self._readonly_props.add(prop)
 
             else:
-                # if column is coming in after _cols_by_table was initialized, ensure the col is in the
-                # right set
-                if hasattr(self, '_cols_by_table') and col.table in self._cols_by_table and col not in self._cols_by_table[col.table]:
+                # if column is coming in after _cols_by_table was 
+                # initialized, ensure the col is in the right set
+                if hasattr(self, '_cols_by_table') and \
+                                    col.table in self._cols_by_table and \
+                                    col not in self._cols_by_table[col.table]:
                     self._cols_by_table[col.table].add(col)
             
             # if this ColumnProperty represents the "polymorphic discriminator"
             # column, mark it.  We'll need this when rendering columns
             # in SELECT statements.
             if not hasattr(prop, '_is_polymorphic_discriminator'):
-                prop._is_polymorphic_discriminator = (col is self.polymorphic_on or prop.columns[0] is self.polymorphic_on)
+                prop._is_polymorphic_discriminator = \
+                                    (col is self.polymorphic_on or
+                                    prop.columns[0] is self.polymorphic_on)
                 
             self.columns[key] = col
             for col in prop.columns:
         for mapper in self.iterate_to_root():
             for (key, cls) in mapper.delete_orphans:
                 if attributes.manager_of_class(cls).has_parent(
-                    state, key, optimistic=_state_has_identity(state)):
+                    state, key, optimistic=state.has_identity):
                     return False
             o = o or bool(mapper.delete_orphans)
         return o
                 connection_callable(self, state.obj()) or \
                 connection
 
-            has_identity = _state_has_identity(state)
+            has_identity = state.has_identity
             mapper = _state_mapper(state)
             instance_key = state.key or mapper._identity_key_from_state(state)
 
                         c = connection.execute(statement.values(value_params), params)
                         
                     mapper._postfetch(uowtransaction, table, 
-                                        state, state_dict, c, c.last_updated_params(), value_params)
+                                        state, state_dict, c, 
+                                        c.last_updated_params(), value_params)
 
                     rows += c.rowcount
 
                     if primary_key is not None:
                         # set primary key attributes
                         for i, col in enumerate(mapper._pks_by_table[table]):
-                            if mapper._get_state_attr_by_column(state, state_dict, col) is None and \
-                                                                len(primary_key) > i:
-                                mapper._set_state_attr_by_column(state, state_dict, col, primary_key[i])
+                            if mapper._get_state_attr_by_column(state, state_dict, col) \
+                                        is None and len(primary_key) > i:
+                                mapper._set_state_attr_by_column(state, state_dict, col,
+                                                                    primary_key[i])
                                 
                     mapper._postfetch(uowtransaction, table, 
-                                        state, state_dict, c, c.last_inserted_params(), value_params)
+                                        state, state_dict, c, c.last_inserted_params(),
+                                        value_params)
 
         if not postupdate:
             for state, state_dict, mapper, connection, has_identity, \
                 readonly = state.unmodified.intersection(
                     p.key for p in mapper._readonly_props
                 )
-
+                
                 if readonly:
                     _expire_state(state, state.dict, readonly)
 
             tups.append((state, 
                     state.dict,
                     _state_mapper(state), 
-                    _state_has_identity(state),
+                    state.has_identity,
                     conn))
 
         table_to_mapper = self._sorted_tables
         raise orm_exc.DetachedInstanceError("Instance %s is not bound to a Session; "
                     "attribute refresh operation cannot proceed" % (state_str(state)))
 
-    has_key = _state_has_identity(state)
-
+    has_key = state.has_identity
+    
     result = False
     if mapper.inherits and not mapper.concrete:
         statement = mapper._optimized_get_statement(state, attribute_names)
             identity_key = state.key
         else:
             identity_key = mapper._identity_key_from_state(state)
+        
         result = session.query(mapper)._get(
                                             identity_key, 
                                             refresh_state=state, 
     # if instance is pending, a refresh operation 
     # may not complete (even if PK attributes are assigned)
     if has_key and result is None:
-        raise orm_exc.ObjectDeletedError("Instance '%s' has been deleted." % state_str(state))
+        raise orm_exc.ObjectDeletedError(
+                            "Instance '%s' has been deleted." % 
+                            state_str(state))

File lib/sqlalchemy/orm/properties.py

                                        self.columns[0], self.key))
 
     def copy(self):
-        return ColumnProperty(deferred=self.deferred, group=self.group, *self.columns)
+        return ColumnProperty(
+                        deferred=self.deferred, 
+                        group=self.group, 
+                        *self.columns)
 
     def _getattr(self, state, dict_, column):
         return state.get_impl(self.key).get(state, dict_)
 
     def _getcommitted(self, state, dict_, column, passive=False):
-        return state.get_impl(self.key).get_committed_value(state, dict_, passive=passive)
+        return state.get_impl(self.key).\
+                    get_committed_value(state, dict_, passive=passive)
 
     def _setattr(self, state, dict_, value, column):
         state.get_impl(self.key).set(state, dict_, value, None)
 
-    def merge(self, session, source_state, source_dict, dest_state, dest_dict, load, _recursive):
+    def merge(self, session, source_state, source_dict, dest_state, 
+                                dest_dict, load, _recursive):
         if self.key in source_dict:
             value = source_dict[self.key]
         
                 impl = dest_state.get_impl(self.key)
                 impl.set(dest_state, dest_dict, value, None)
         else:
-            if self.key not in dest_dict:
+            if dest_state.has_identity and self.key not in dest_dict:
                 dest_state.expire_attributes(dest_dict, [self.key])
                 
     def get_col_value(self, column, value):
             if self.adapter:
                 return self.adapter(self.prop.columns[0])
             else:
-                return self.prop.columns[0]._annotate({"parententity": self.mapper, "parentmapper":self.mapper})
+                return self.prop.columns[0]._annotate({
+                                                "parententity": self.mapper,
+                                                "parentmapper":self.mapper})
                 
         def operate(self, op, *other, **kwargs):
             return op(self.__clause_element__(), *other, **kwargs)

File lib/sqlalchemy/orm/session.py

 from sqlalchemy.orm.util import object_mapper as _object_mapper
 from sqlalchemy.orm.util import class_mapper as _class_mapper
 from sqlalchemy.orm.util import (
-    _class_to_mapper, _state_has_identity, _state_mapper,
+    _class_to_mapper, _state_mapper,
     )
 from sqlalchemy.orm.mapper import Mapper, _none_set
 from sqlalchemy.orm.unitofwork import UOWTransaction
             if state.key is None:
                 state.key = instance_key
             elif state.key != instance_key:
-                # primary key switch.
-                # use discard() in case another state has already replaced this
-                # one in the identity map (see test/orm/test_naturalpks.py ReversePKsTest)
+                # primary key switch. use discard() in case another 
+                # state has already replaced this one in the identity 
+                # map (see test/orm/test_naturalpks.py ReversePKsTest)
                 self.identity_map.discard(state)
                 state.key = instance_key
             
             
         for state in proc:
             is_orphan = _state_mapper(state)._is_orphan(state)
-            if is_orphan and not _state_has_identity(state):
+            if is_orphan and not state.has_identity:
                 path = ", nor ".join(
                     ["any parent '%s' instance "
                      "via that classes' '%s' attribute" %

File lib/sqlalchemy/orm/state.py

     @util.memoized_property
     def callables(self):
         return {}
+
+    @property
+    def has_identity(self):
+        return bool(self.key)
         
     def detach(self):
         if self.session_id:
         If all attributes are expired, the "expired" flag is set to True.
         
         """
+        # we would like to assert that 'self.key is not None' here, 
+        # but there are many cases where the mapper will expire
+        # a newly persisted instance within the flush, before the
+        # key is assigned, and even cases where the attribute refresh
+        # occurs fully, within the flush(), before this key is assigned.
+        # the key is assigned late within the flush() to assist in
+        # "key switch" bookkeeping scenarios.
+        
         if attribute_names is None:
             attribute_names = self.manager.keys()
             self.expired = True

File lib/sqlalchemy/orm/strategies.py

                                         path, adapter, **kwargs)
     
     def _class_level_loader(self, state):
-        if not mapperutil._state_has_identity(state):
+        if not state.has_identity:
             return None
             
         return LoadDeferredColumns(state, self.key)
         return criterion
         
     def _class_level_loader(self, state):
-        if not mapperutil._state_has_identity(state):
+        if not state.has_identity:
             return None
 
         return LoadLazyAttribute(state, self.key)

File lib/sqlalchemy/orm/util.py

 
 def has_identity(object):
     state = attributes.instance_state(object)
-    return _state_has_identity(state)
-
-def _state_has_identity(state):
-    return bool(state.key)
+    return state.has_identity
 
 def _is_mapped_class(cls):
     global mapperlib

File test/orm/test_merge.py

         sess.commit()
 
     @testing.resolve_artifact_names
+    def test_dont_expire_pending(self):
+        """test that pending instances aren't expired during a merge."""
+        
+        mapper(User, users)
+        u = User(id=7)
+        sess = create_session(autoflush=True, autocommit=False)
+        u = sess.merge(u)
+        assert not bool(attributes.instance_state(u).expired_attributes)
+        def go():
+            eq_(u.name, None)
+        self.assert_sql_count(testing.db, go, 0)
+    
+    @testing.resolve_artifact_names
     def test_option_state(self):
         """test that the merged takes on the MapperOption characteristics
         of that which is merged.