Commits

Mike Bayer committed c1c8a68

refinement pass thru savepoint API

Comments (0)

Files changed (4)

lib/sqlalchemy/orm/attributes.py

                 
         self.modified = True
     
-    # TODO: reorganize savepoints so that the savepoint record is managed externally to the InstanceState.
-    # InstanceState only references the most recent savepoint so that its available in modified_event().
-    
-    def set_savepoint(self, id_=None):
+    def set_savepoint(self, id_):
+        """set a savepoint with the given id."""
+        
         savepoint = {}
         for key in self.manager.mutable_attributes:
             if key in self.dict:
                 savepoint[key] = self.manager[key].impl.copy(self.dict[key])
-                
+        
         self.savepoints.append((savepoint, self.parents.copy(), self.pending.copy(), self.committed_state.copy(), self.modified, id_))
 
-    def remove_savepoint(self, id_=None):
-        if not self.savepoints:
-            raise sa_exc.AssertionError("Savepoint id %s does not exist"  % (id_))
+    def remove_savepoint(self, id_):
+        """remove a given savepoint"""
+        
+        try:
+            sp = self.savepoints.pop()
+        except IndexError:
+            raise sa_exc.InvalidRequestError("No savepoints are set; can't remove savepoint.")
             
-        sp = self.savepoints.pop()
-        spid = sp[5]
-        if spid != id_:
-            raise sa_exc.AssertionError("Savepoint id %s does not match %s"  % (spid, id_))
+        if sp[5] != id_:
+            raise sa_exc.AssertionError("Savepoint id %s does not match %s"  % (sp[5], id_))
 
-    def rollback(self, id_=None):
-        if not self.savepoints:
+    def rollback(self, id_):
+        """roll back to a given savepoint"""
+        
+        try:
+            (savepoint, self.parents, self.pending, self.committed_state, self.modified, spid) = self.savepoints.pop()
+        except IndexError:
             raise sa_exc.InvalidRequestError("No savepoints are set; can't rollback.")
-        
-        (savepoint, self.parents, self.pending, self.committed_state, self.modified, spid) = self.savepoints.pop()
+            
         if spid != id_:
             raise sa_exc.AssertionError("Savepoint id %s does not match %s"  % (spid, id_))
         
                 self.expired_attributes.remove(key)
                 self.callables.pop(key, None)
 
-
-    def commit_all(self):
+    def commit_all(self, savepoint_id=None):
         """commit all attributes unconditionally.
 
         This is used after a flush() or a full load/refresh
          - the "modified" flag is set to False
          - any "expired" markers/callables are removed.
 
-
         Attributes marked as "expired" can potentially remain "expired" after this step
         if a value was not populated in state.dict.
+        
+        The argument "savepoint_id" indicates that a savepoint should be 
+        set up along with the commit.  When present, the functionality of 
+        "set_savepoint()" is inlined here for better performance than a 
+        separate call.
+        
         """
         self.committed_state = {}
-
+        
         # unexpire attributes which have loaded
-        for key in list(self.expired_attributes):
-            if key in self.dict:
-                self.expired_attributes.remove(key)
+        if self.expired_attributes:
+            for key in self.expired_attributes.intersection(self.dict):
                 self.callables.pop(key, None)
+            self.expired_attributes.difference_update(self.dict)
         
         for key in self.manager.mutable_attributes:
             if key in self.dict:
                 self.manager[key].impl.commit_to_state(self, self.committed_state)
+
+        if savepoint_id:
+            cstate = self.committed_state.copy()
+            self.savepoints.append((cstate, self.parents.copy(), self.pending.copy(), cstate, False, savepoint_id))
                     
         self.modified = self.expired = False
         self._strong_obj = None

lib/sqlalchemy/orm/session.py

         else:
             snapshot_id = None    
         for state in states:
-            state.commit_all()
-            if snapshot_id:
-                state.set_savepoint(snapshot_id)
+            state.commit_all(savepoint_id=snapshot_id)
 
     def get(self, class_, ident, entity_name=None):
         """Return an instance of the object based on the given

test/orm/attr_rollback.py

 
         f = Foo()
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([], emptyhist, [])
-        attributes.instance_state(f).set_savepoint()
+        attributes.instance_state(f).set_savepoint(1)
         f.x = data1
         assert attributes.get_history(attributes.instance_state(f), 'x') == (hist1, [], [])
-        attributes.instance_state(f).rollback()
+        attributes.instance_state(f).rollback(1)
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([], emptyhist, [])
 
         assert f.x == empty
     def test_needs_savepoint(self):
         f = Foo()
         f.x = data1
-        self.assertRaises(sa_exc.InvalidRequestError, attributes.instance_state(f).rollback)
+        self.assertRaises(sa_exc.InvalidRequestError, attributes.instance_state(f).rollback, 1)
+
+        self.assertRaises(sa_exc.InvalidRequestError, attributes.instance_state(f).remove_savepoint, 1)
+    
+    def test_savepoint_matchup(self):
+        f = Foo()
+        attributes.instance_state(f).set_savepoint(1)
+        attributes.instance_state(f).set_savepoint(2)
+        self.assertRaises(sa_exc.AssertionError, attributes.instance_state(f).rollback, 1)
+        
+        f = Foo()
+        attributes.instance_state(f).set_savepoint(1)
+        attributes.instance_state(f).set_savepoint(2)
+        self.assertRaises(sa_exc.AssertionError, attributes.instance_state(f).remove_savepoint, 1)
+
+        f = Foo()
+        attributes.instance_state(f).set_savepoint(1)
+        self.assertRaises(sa_exc.AssertionError, attributes.instance_state(f).rollback, 2)
+
+        f = Foo()
+        attributes.instance_state(f).set_savepoint(1)
+        self.assertRaises(sa_exc.AssertionError, attributes.instance_state(f).remove_savepoint, 2)
         
     def test_rback_to_set(self):
         f = Foo()
         f.x = data1
         attributes.instance_state(f).commit_all()
-        attributes.instance_state(f).set_savepoint()
+        attributes.instance_state(f).set_savepoint(1)
         f.x = empty
-        attributes.instance_state(f).rollback()
+        attributes.instance_state(f).rollback(1)
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([], hist1, [])
 
     def test_rback_savepoint_to_set(self):
         f = Foo()
         f.x = data1
-        attributes.instance_state(f).set_savepoint()
+        attributes.instance_state(f).set_savepoint(1)
         f.x = empty
-        attributes.instance_state(f).rollback()
+        attributes.instance_state(f).rollback(1)
         assert f.x == data1
         assert attributes.get_history(attributes.instance_state(f), 'x') == (hist1, [], [])
         
         f = Foo()
         f.x = data1
         attributes.instance_state(f).commit_all()
-        attributes.instance_state(f).set_savepoint()
+        attributes.instance_state(f).set_savepoint(1)
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([], hist1, [])
 
-        attributes.instance_state(f).rollback()
+        attributes.instance_state(f).rollback(1)
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([], hist1, [])
 
     def test_rback_savepoint_rback_to_committed(self):
         attributes.instance_state(f).commit_all()
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([], hist1, [])
 
-        attributes.instance_state(f).set_savepoint()
+        attributes.instance_state(f).set_savepoint(1)
 
         f.x = data2
         assert attributes.get_history(attributes.instance_state(f), 'x') == (hist2, [], hist1)
 
-        attributes.instance_state(f).set_savepoint()
+        attributes.instance_state(f).set_savepoint(2)
         assert attributes.get_history(attributes.instance_state(f), 'x') == (hist2, [], hist1)
 
         f.x = data3
         assert attributes.get_history(attributes.instance_state(f), 'x') == (hist3, [], hist1)
         
-        attributes.instance_state(f).rollback()
+        attributes.instance_state(f).rollback(2)
         assert attributes.get_history(attributes.instance_state(f), 'x') == (hist2, [], hist1)
 
-        attributes.instance_state(f).rollback()
+        attributes.instance_state(f).rollback(1)
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([], hist1, [])
 
     def test_rback_savepoint_commit(self):
 
         f.x = data2
         aeq(attributes.get_history(attributes.instance_state(f), 'x'), (hist2, [], hist1))
-        attributes.instance_state(f).set_savepoint()
+        attributes.instance_state(f).set_savepoint(1)
         aeq(attributes.get_history(attributes.instance_state(f), 'x'), (hist2, [], hist1))
 
         f.x = data3
         aeq(attributes.get_history(attributes.instance_state(f), 'x'), (hist3, [], hist1))
 
-        attributes.instance_state(f).rollback()
+        attributes.instance_state(f).rollback(1)
         aeq(attributes.get_history(attributes.instance_state(f), 'x'), (hist2, [], hist1))
 
         attributes.instance_state(f).commit_all()
 
         f.x = data2
         assert attributes.get_history(attributes.instance_state(f), 'x') == (hist2, [], hist1)
-        attributes.instance_state(f).set_savepoint()
+        attributes.instance_state(f).set_savepoint(1)
         assert attributes.get_history(attributes.instance_state(f), 'x') == (hist2, [], hist1)
 
         f.x = data3
         assert attributes.get_history(attributes.instance_state(f), 'x') == (hist3, [], hist1)
 
-        attributes.instance_state(f).remove_savepoint()
+        attributes.instance_state(f).remove_savepoint(1)
 
         attributes.instance_state(f).commit_all()
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([], hist3, [])
     def test_rback_savepoint_to_empty(self):
         f = Foo()
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([], emptyhist, [])
-        attributes.instance_state(f).set_savepoint()
+        attributes.instance_state(f).set_savepoint(1)
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([], emptyhist, [])
         f.x = data3
         assert attributes.get_history(attributes.instance_state(f), 'x') == (hist3, [], [])
-        attributes.instance_state(f).rollback()
+        attributes.instance_state(f).rollback(1)
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([], emptyhist, [])
     
     def test_commit_to_savepoint(self):
         
         f = Foo()
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([], emptyhist, [])
-        attributes.instance_state(f).set_savepoint()
+        attributes.instance_state(f).set_savepoint(1)
         f.x = data1
         attributes.instance_state(f).commit_all()
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([], hist1, [])
-        attributes.instance_state(f).rollback()
+        attributes.instance_state(f).rollback(1)
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([], emptyhist, [])
 
     def test_multiple_commit_to_savepoint_rback(self):
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([], hist1, [])
 
         # begin transaction
-        attributes.instance_state(f).set_savepoint()
+        attributes.instance_state(f).set_savepoint(1)
         
         # change things
         f.x = data2
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([], hist3, [])
         
         # rollback transaction
-        attributes.instance_state(f).rollback()
+        attributes.instance_state(f).rollback(1)
         
         # back to beginning
         assert f.x == data1
         f.x = data1
         attributes.instance_state(f).commit_all()
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([], hist1, [])
-        attributes.instance_state(f).set_savepoint()
+        attributes.instance_state(f).set_savepoint(1)
         f.x = data2
         assert attributes.get_history(attributes.instance_state(f), 'x') == (hist2, [], hist1)
         attributes.instance_state(f).commit_all()
         attributes.instance_state(f).commit_all()
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([], hist1, [])
         
-        attributes.instance_state(f).set_savepoint()
+        attributes.instance_state(f).set_savepoint(1)
         
         f.x = data2
         assert attributes.get_history(attributes.instance_state(f), 'x') == (hist2, [], hist1)
         attributes.instance_state(f).commit_all()
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([], hist2, [])
 
-        attributes.instance_state(f).set_savepoint()
+        attributes.instance_state(f).set_savepoint(2)
         
         f.x = data3
         assert attributes.get_history(attributes.instance_state(f), 'x') == (hist3, [], hist2)
         attributes.instance_state(f).commit_all()
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([], hist3, [])
         
-        attributes.instance_state(f).rollback()
+        attributes.instance_state(f).rollback(2)
         
         assert f.x == data2
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([], hist2, [])
         
-        attributes.instance_state(f).rollback()
+        attributes.instance_state(f).rollback(1)
+        assert f.x == data1
+        assert attributes.get_history(attributes.instance_state(f), 'x') == ([], hist1, [])
+
+    def test_multiple_commit_to_nested_savepoint_rback_inline(self):
+
+        f = Foo()
+        assert attributes.get_history(attributes.instance_state(f), 'x') == ([], emptyhist, [])
+        f.x = data1
+        attributes.instance_state(f).commit_all(savepoint_id="X")
+        assert attributes.get_history(attributes.instance_state(f), 'x') == ([], hist1, [])
+
+        f.x = data2
+        assert attributes.get_history(attributes.instance_state(f), 'x') == (hist2, [], hist1)
+        attributes.instance_state(f).commit_all(savepoint_id="Y")
+        assert attributes.get_history(attributes.instance_state(f), 'x') == ([], hist2, [])
+
+        f.x = data3
+        assert attributes.get_history(attributes.instance_state(f), 'x') == (hist3, [], hist2)
+        attributes.instance_state(f).commit_all(savepoint_id="Z")
+        assert attributes.get_history(attributes.instance_state(f), 'x') == ([], hist3, [])
+
+        attributes.instance_state(f).rollback(id_="Z")
+        attributes.instance_state(f).rollback(id_="Y")
+
+        assert f.x == data2
+        assert attributes.get_history(attributes.instance_state(f), 'x') == ([], hist2, [])
+
+        attributes.instance_state(f).rollback(id_="X")
         assert f.x == data1
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([], hist1, [])
 
         attributes.instance_state(f).commit_all()
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([], hist1, [])
 
-        attributes.instance_state(f).set_savepoint()
+        attributes.instance_state(f).set_savepoint(1)
 
         f.x = data2
         assert attributes.get_history(attributes.instance_state(f), 'x') == (hist2, [], hist1)
         attributes.instance_state(f).commit_all()
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([], hist2, [])
 
-        attributes.instance_state(f).set_savepoint()
+        attributes.instance_state(f).set_savepoint(2)
 
         f.x = data3
         assert attributes.get_history(attributes.instance_state(f), 'x') == (hist3, [], hist2)
         attributes.instance_state(f).commit_all()
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([], hist3, [])
 
-        attributes.instance_state(f).rollback()
+        attributes.instance_state(f).rollback(2)
+
+        assert f.x == data2
+        assert attributes.get_history(attributes.instance_state(f), 'x') == ([], hist2, [])
+
+        attributes.instance_state(f).commit_all()
+        assert f.x == data2
+        assert attributes.get_history(attributes.instance_state(f), 'x') == ([], hist2, [])
+
+    def test_multiple_commit_to_nested_savepoint_commit_inline(self):
+
+        f = Foo()
+        assert attributes.get_history(attributes.instance_state(f), 'x') == ([], emptyhist, [])
+        f.x = data1
+        attributes.instance_state(f).commit_all(savepoint_id="X")
+        assert attributes.get_history(attributes.instance_state(f), 'x') == ([], hist1, [])
+
+        f.x = data2
+        assert attributes.get_history(attributes.instance_state(f), 'x') == (hist2, [], hist1)
+        attributes.instance_state(f).commit_all(savepoint_id="Y")
+        assert attributes.get_history(attributes.instance_state(f), 'x') == ([], hist2, [])
+
+        f.x = data3
+        assert attributes.get_history(attributes.instance_state(f), 'x') == (hist3, [], hist2)
+        attributes.instance_state(f).commit_all()
+        assert attributes.get_history(attributes.instance_state(f), 'x') == ([], hist3, [])
+
+        attributes.instance_state(f).rollback(id_="Y")
 
         assert f.x == data2
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([], hist2, [])
         attributes.instance_state(f).commit_all()
         f.x = data1
         
-        attributes.instance_state(f).set_savepoint()
+        attributes.instance_state(f).set_savepoint(1)
         f.x = data2
-        attributes.instance_state(f).rollback()
+        attributes.instance_state(f).rollback(1)
         assert f.x == data1
         assert attributes.get_history(attributes.instance_state(f), 'x') == (hist1, [], [])
 
         f.x = data1
         assert attributes.get_history(attributes.instance_state(f), 'x') == (hist1, [], [])
         
-        attributes.instance_state(f).set_savepoint()
+        attributes.instance_state(f).set_savepoint(1)
         f.x = data2
         assert attributes.get_history(attributes.instance_state(f), 'x') == (hist2, [], [])
         
-        attributes.instance_state(f).remove_savepoint()
+        attributes.instance_state(f).remove_savepoint(1)
         attributes.instance_state(f).commit_all()
 
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([], hist2, [])
         b1, b2, b3, b4, b5 = Bar(), Bar(), Bar(), Bar(), Bar()
         f = Foo()
 
-        attributes.instance_state(f).set_savepoint()
+        attributes.instance_state(f).set_savepoint(1)
         
         f.x.append(b2)
         f.x.append(b3)
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([b2, b3], [], [])
         assert f.x == make_collection([b2, b3])
 
-        attributes.instance_state(f).set_savepoint()
+        attributes.instance_state(f).set_savepoint(2)
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([b2, b3], [], [])
 
         f.x.remove(b3)
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([b2, b1], [], [])
         assert f.x == make_collection([b2, b1])
 
-        attributes.instance_state(f).rollback()
+        attributes.instance_state(f).rollback(2)
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([b2, b3], [], [])
         assert f.x == make_collection([b2, b3])
 
-        attributes.instance_state(f).rollback()
+        attributes.instance_state(f).rollback(1)
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([], [], [])
 
     def test_collection_rback_savepoint_commit(self):
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([b2, b3], [], [])
         assert f.x == make_collection([b2, b3])
 
-        attributes.instance_state(f).set_savepoint()
+        attributes.instance_state(f).set_savepoint(1)
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([b2, b3], [], [])
 
         f.x.remove(b3)
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([b2, b1], [], [])
         assert f.x == make_collection([b2, b1])
 
-        attributes.instance_state(f).rollback()
+        attributes.instance_state(f).rollback(1)
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([b2, b3], [], [])
         assert f.x == make_collection([b2, b3])
 
         assert not attributes.has_parent(Foo, b3, 'x', optimistic=False)
 
         for x in (f, b1, b2, b3):
-            x._foostate.set_savepoint()
+            x._foostate.set_savepoint(1)
         f.x.append(b3)
         f.x.remove(b2)
         assert attributes.has_parent(Foo, b1, 'x', optimistic=False)
         assert attributes.has_parent(Foo, b3, 'x', optimistic=False)
         
         for x in (f, b1, b2, b3):
-            x._foostate.rollback()
+            x._foostate.rollback(1)
 
         assert attributes.has_parent(Foo, b1, 'x', optimistic=False)
         assert attributes.has_parent(Foo, b2, 'x', optimistic=False)
 
         f.x = {'data':5}
         attributes.instance_state(f).commit_all()
-        attributes.instance_state(f).set_savepoint()
+        attributes.instance_state(f).set_savepoint(1)
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([], hist1, [])
 
         f.x['foo'] = 9
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([{'data':5, 'foo':9}], [], [{'data':5}])
         
-        attributes.instance_state(f).rollback()
+        attributes.instance_state(f).rollback(1)
         
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([], hist1, [])
         
         assert attributes.has_parent(Foo, b, 'x', optimistic=False)
         
         for x in (f, f2, b):
-            x._foostate.set_savepoint()
+            x._foostate.set_savepoint(1)
         f2.x = None
         f.x = b
         assert attributes.has_parent(Foo, b, 'x', optimistic=False)
         
         for x in (f, f2, b):
-            x._foostate.rollback()
+            x._foostate.rollback(1)
             
         assert f2.x == b
         assert attributes.has_parent(Foo, b, 'x', optimistic=False)

test/perf/masseagerload.py

     session.begin()
     query = session.query(Item)
     l = query.all()
-    session.commit()
     print "loaded ", len(l), " items each with ", len(l[0].subs), "subitems"
 
 def all():