1. Ben Trofatter
  2. sqlalchemy-2663

Commits

Mike Bayer  committed 51c32cf

recheck the dirty list if extensions are present

  • Participants
  • Parent commits 64dfbe1
  • Branches default

Comments (0)

Files changed (2)

File lib/sqlalchemy/orm/session.py

View file
  • Ignore whitespace
             self.identity_map.modified = False
             return
 
-        flush_context = UOWTransaction(self)
+        flush_context   = UOWTransaction(self)
 
-        for ext in self.extensions:
-            ext.before_flush(self, flush_context, objects)
-
+        if self.extensions:
+            for ext in self.extensions:
+                ext.before_flush(self, flush_context, objects)
+            dirty = self._dirty_states
+            
         deleted = set(self._deleted)
         new = set(self._new)
 

File test/orm/session.py

View file
  • Ignore whitespace
                     if isinstance(obj, User):
                         x = session.query(User).filter(User.name=='another %s' % obj.name).one()
                         session.delete(x)
-                        
+                    
         sess = create_session(extension = MyExt(), autoflush=True)
         u = User(name='u1')
         sess.add(u)
         )
 
     @testing.resolve_artifact_names
+    def test_before_flush_affects_dirty(self):
+        mapper(User, users)
+        
+        class MyExt(sa.orm.session.SessionExtension):
+            def before_flush(self, session, flush_context, objects):
+                for obj in list(session.identity_map.values()):
+                    obj.name += " modified"
+                    
+        sess = create_session(extension = MyExt(), autoflush=True)
+        u = User(name='u1')
+        sess.add(u)
+        sess.flush()
+        self.assertEquals(sess.query(User).order_by(User.name).all(), 
+            [
+                User(name='u1')
+            ]
+        )
+        
+        sess.add(User(name='u2'))
+        sess.flush()
+        sess.clear()
+        self.assertEquals(sess.query(User).order_by(User.name).all(), 
+            [
+                User(name='u1 modified'),
+                User(name='u2')
+            ]
+        )
+
+    @testing.resolve_artifact_names
     def test_reentrant_flush(self):
 
         mapper(User, users)