1. Christoph Zwerschke
  2. sqlalchemy

Commits

Mike Bayer  committed 97a4fde

- improved the attribute and state accounting performed by query.update() and query.delete()
- added autoflush support to same

  • Participants
  • Parent commits 8d1648a
  • Branches default

Comments (0)

Files changed (2)

File lib/sqlalchemy/orm/query.py

View file
  • Ignore whitespace
             select_stmt = context.statement.with_only_columns(primary_table.primary_key)
             matched_rows = session.execute(select_stmt).fetchall()
         
+        if self._autoflush:
+            session._autoflush()
         session.execute(delete_stmt)
         
         if synchronize_session == 'evaluate':
             objs_to_expunge = [obj for (cls, pk, entity_name),obj in session.identity_map.iteritems()
                 if issubclass(cls, target_cls) and eval_condition(obj)]
             for obj in objs_to_expunge:
-                session.expunge(obj)
+                session._remove_newly_deleted(attributes.instance_state(obj))
         elif synchronize_session == 'fetch':
             target_mapper = self._mapper_zero()
             for primary_key in matched_rows:
                 identity_key = target_mapper.identity_key_from_primary_key(list(primary_key))
                 if identity_key in session.identity_map:
-                    session.expunge(session.identity_map[identity_key])
+                    session._remove_newly_deleted(attributes.instance_state(session.identity_map[identity_key]))
 
     def update(self, values, synchronize_session='evaluate'):
         """EXPERIMENTAL"""
             select_stmt = context.statement.with_only_columns(primary_table.primary_key)
             matched_rows = session.execute(select_stmt).fetchall()
         
+        if self._autoflush:
+            session._autoflush()
         session.execute(update_stmt)
         
         if synchronize_session == 'evaluate':
             target_cls = self._mapper_zero().class_
             
             for (cls, pk, entity_name),obj in session.identity_map.iteritems():
+                evaluated_keys = value_evaluators.keys()
+                
                 if issubclass(cls, target_cls) and eval_condition(obj):
-                    for key,eval_value in value_evaluators.items():
-                        obj.__dict__[key] = eval_value(obj)
-        
+                    state = attributes.instance_state(obj)
+                    
+                    # only evaluate unmodified attributes
+                    to_evaluate = state.unmodified.intersection(evaluated_keys)
+                    for key in to_evaluate:
+                        state.dict[key] = value_evaluators[key](obj)
+                            
+                    state.commit(list(to_evaluate))
+                    
+                    # expire attributes with pending changes (there was no autoflush, so they are overwritten)
+                    state.expire_attributes(util.Set(evaluated_keys).difference(to_evaluate))
+                    
         elif synchronize_session == 'expire':
             target_mapper = self._mapper_zero()
             

File test/orm/query.py

View file
  • Ignore whitespace
                 b1
         )
         
-class UpdateTest(_base.MappedTest):
+class UpdateDeleteTest(_base.MappedTest):
     def define_tables(self, metadata):
         Table('users', metadata,
               Column('id', Integer, primary_key=True),
         eq_(sess.query(User).order_by(User.id).all(), [jack,jane])
         
     @testing.resolve_artifact_names
+    def test_delete_rollback(self):
+        sess = sessionmaker()()
+        john,jack,jill,jane = sess.query(User).order_by(User.id).all()
+        sess.query(User).filter(or_(User.name == 'john', User.name == 'jill')).delete()
+        assert john not in sess and jill not in sess
+        sess.rollback()
+        assert john in sess and jill in sess
+
+    @testing.resolve_artifact_names
+    def test_delete_rollback_with_fetch(self):
+        sess = sessionmaker()()
+        john,jack,jill,jane = sess.query(User).order_by(User.id).all()
+        sess.query(User).filter(or_(User.name == 'john', User.name == 'jill')).delete(synchronize_session='fetch')
+        assert john not in sess and jill not in sess
+        sess.rollback()
+        assert john in sess and jill in sess
+        
+    @testing.resolve_artifact_names
     def test_delete_without_session_sync(self):
         sess = create_session(bind=testing.db, autocommit=False)
         
         eq_(sess.query(User.age).order_by(User.id).all(), zip([25,37,29,27]))
 
     @testing.resolve_artifact_names
+    def test_update_changes_resets_dirty(self):
+        sess = create_session(bind=testing.db, autocommit=False, autoflush=False)
+
+        john,jack,jill,jane = sess.query(User).order_by(User.id).all()
+        
+        john.age = 50
+        jack.age = 37
+        
+        # autoflush is false.  therefore our '50' and '37' are getting blown away by this operation.
+        
+        sess.query(User).filter(User.age > 29).update({'age': User.age - 10})
+
+        for x in (john, jack, jill, jane):
+            assert not sess.is_modified(x)
+
+        eq_([john.age, jack.age, jill.age, jane.age], [25,37,29,27])
+        
+        john.age = 25
+        assert john in sess.dirty
+        assert jack in sess.dirty
+        assert jill not in sess.dirty
+        assert not sess.is_modified(john)
+        assert not sess.is_modified(jack)
+
+    @testing.resolve_artifact_names
+    def test_update_changes_with_autoflush(self):
+        sess = create_session(bind=testing.db, autocommit=False, autoflush=True)
+
+        john,jack,jill,jane = sess.query(User).order_by(User.id).all()
+
+        john.age = 50
+        jack.age = 37
+
+        sess.query(User).filter(User.age > 29).update({'age': User.age - 10})
+
+        for x in (john, jack, jill, jane):
+            assert not sess.is_modified(x)
+
+        eq_([john.age, jack.age, jill.age, jane.age], [40, 27, 29, 27])
+
+        john.age = 25
+        assert john in sess.dirty
+        assert jack not in sess.dirty
+        assert jill not in sess.dirty
+        assert sess.is_modified(john)
+        assert not sess.is_modified(jack)
+        
+        
+
+    @testing.resolve_artifact_names
     def test_update_with_expire_strategy(self):
         sess = create_session(bind=testing.db, autocommit=False)