Commits

Mike Bayer  committed 4db7a7b

- Changes made to new, dirty and deleted
collections in
SessionExtension.before_flush() will take
effect for that flush.

  • Participants
  • Parent commits c22c08f
  • Branches rel_0_4

Comments (0)

Files changed (3)

       with "A=B" versus "B=A" leading to errors
       [ticket:1039]
 
+    - Changes made to new, dirty and deleted 
+      collections in
+      SessionExtension.before_flush() will take
+      effect for that flush.
+
 - mysql
     - Added MSMediumInteger type [ticket:1146].
 

File lib/sqlalchemy/orm/unitofwork.py

         if not dirty and not self.deleted and not self.new:
             return
         
+        flush_context = UOWTransaction(self, session)
+
+        if session.extension is not None:
+            session.extension.before_flush(session, flush_context, objects)
+            dirty = [x for x in self.identity_map.all_states()
+                if x.modified
+                or (x.class_._class_state.has_mutable_scalars and x.is_modified())
+            ]
+
         deleted = util.Set(self.deleted)
         new = util.Set(self.new)
         
         dirty = util.Set(dirty).difference(deleted)
-        
-        flush_context = UOWTransaction(self, session)
-
-        if session.extension is not None:
-            session.extension.before_flush(session, flush_context, objects)
 
         # create the set of all objects we want to operate upon
         if objects:

File test/orm/session.py

         conn = sess.connection()
         assert log == ['after_begin']
 
+    def test_before_flush_affects_dirty(self):
+        class User(fixtures.Base):
+            pass
+        mapper(User, users)
+
+        class MyExt(SessionExtension):
+            def before_flush(self, session, flush_context, objects):
+                for obj in list(session.identity_map.values()):
+                    obj.user_name += " modified"
+
+        sess = create_session(extension = MyExt(), autoflush=True)
+        u = User(user_name='u1')
+        sess.add(u)
+        sess.flush()
+        self.assertEquals(sess.query(User).order_by(User.user_name).all(),
+            [
+                User(user_name='u1')
+            ]
+        )
+
+        sess.add(User(user_name='u2'))
+        sess.flush()
+        sess.clear()
+        self.assertEquals(sess.query(User).order_by(User.user_name).all(),
+            [
+                User(user_name='u1 modified'),
+                User(user_name='u2')
+            ]
+        )
+
     def test_pickled_update(self):
         mapper(User, users)
         sess1 = create_session()