Implement version_id support for bulk_save / bulk_update

Issue #3781 resolved
Mike Bayer (repo owner) created an issue
diff --git a/test/orm/test_bulk.py b/test/orm/test_bulk.py
index 7e1b052..6acd6d2 100644
--- a/test/orm/test_bulk.py
+++ b/test/orm/test_bulk.py
@@ -13,6 +13,62 @@ class BulkTest(testing.AssertsExecutionResults):
     run_define_tables = 'each'


+class BulkInsertUpdateVersionId(BulkTest, fixtures.MappedTest):
+    """bulk_save_objects() against a mapping with version_id_col (#3781)."""
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table('version_table', metadata,
+              Column('id', Integer, primary_key=True,
+                     test_needs_autoincrement=True),
+              Column('version_id', Integer, nullable=False),
+              Column('value', String(40), nullable=False))
+
+    @classmethod
+    def setup_classes(cls):
+        class Foo(cls.Comparable):
+            pass
+
+    @classmethod
+    def setup_mappers(cls):
+        Foo, version_table = cls.classes.Foo, cls.tables.version_table
+
+        mapper(Foo, version_table, version_id_col=version_table.c.version_id)
+
+    def test_bulk_insert_via_save(self):
+        Foo = self.classes.Foo
+
+        s = Session()
+
+        s.bulk_save_objects([Foo(value='value')])
+
+        eq_(
+            s.query(Foo).all(),
+            [Foo(version_id=1, value='value')]
+        )
+
+    def test_bulk_update_via_save(self):
+        Foo = self.classes.Foo
+
+        s = Session()
+
+        s.add(Foo(value='value'))
+        s.commit()
+
+        f1 = s.query(Foo).first()
+        f1.value = 'new value'
+        s.bulk_save_objects([f1])
+        # bulk_save_objects() does no object bookkeeping: f1 is now stale
+        # (version_id still 1) and still dirty, so clear the session
+        # before re-querying.
+        s.expunge_all()
+
+        eq_(
+            s.query(Foo).all(),
+            [Foo(version_id=2, value='new value')]
+        )
+
+
 class BulkInsertUpdateTest(BulkTest, _fixtures.FixtureTest):

     @classmethod

patch:

diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index 5d69f51..467f47f 100644
--- a/lib/sqlalchemy/orm/persistence.py
+++ b/lib/sqlalchemy/orm/persistence.py
@@ -82,11 +82,18 @@ def _bulk_update(mapper, mappings, session_transaction,

     cached_connections = _cached_connection_dict(base_mapper)

+    # Keys copied into each UPDATE mapping even when unchanged: the
+    # primary key, plus the version id when the mapper is versioned, so
+    # that _emit_update_statements() receives the current version value.
+    search_keys = set(mapper._primary_key_propkeys)
+    if mapper._version_id_prop:
+        search_keys.add(mapper._version_id_prop.key)
+
     def _changed_dict(mapper, state):
         return dict(
             (k, v)
             for k, v in state.dict.items() if k in state.committed_state or k
-            in mapper._primary_key_propkeys
+            in search_keys
         )

     if isstates:

Also tricky: if you do a bulk update on an object that is also in the session, the copy of that object you hold locally is now stale, because the bulk operations don't do any object bookkeeping.

also, tricky, need to close

Comments (3)

  1. Mike Bayer reporter

    Consider version_id_prop when emitting bulk UPDATE

    The version id needs to be part of _changed_dict() so that its value is present when the mapping is passed to _emit_update_statements().

    Change-Id: Ia85f0ef7714296a75cdc6c88674805afbbe752c8 Fixes: #3781

    → <<cset c9d8a67b52d1>>

  2. Mike Bayer reporter

    Consider version_id_prop when emitting bulk UPDATE

    The version id needs to be part of _changed_dict() so that its value is present when the mapping is passed to _emit_update_statements().

    Change-Id: Ia85f0ef7714296a75cdc6c88674805afbbe752c8 Fixes: #3781

    → <<cset e2a976e91657>>

  3. Log in to comment