1. Daniel Miller
  2. sqlalchemy

Commits

Mike Bayer  committed 2f9b4a7

some more rollback tests and some uncertainty of what do to with expiry/lazy loaders

  • Participants
  • Parent commits 339f1db
  • Branches user_defined_state

Comments (0)

Files changed (3)

File lib/sqlalchemy/orm/attributes.py

View file
  • Ignore whitespace
         state.dict[self.key] = user_data
 
         state.commit([self.key])
+
         if self.key in state.pending:
             # pending items exist.  issue a modified event,
             # add/remove new items.
         """
 
         class_manager = self.manager
+        savepoint = self.savepoints and self.savepoints[-1][0] or None
         for key in keys:
             if key in self.dict and key in class_manager.mutable_attributes:
                 class_manager[key].impl.commit_to_state(self, self.committed_state)
+                if savepoint is not None:
+                    # ugh  - TODO: do we have to commit up the chain of all savepoints ??
+                    class_manager[key].impl.commit_to_state(self, savepoint)
             else:
                 self.committed_state.pop(key, None)
+             #   if savepoint:  # what about this?  
+             #       savepoint.pop(key, None)
 
         self.expired = False
         # unexpire attributes which have loaded

File lib/sqlalchemy/orm/session.py

View file
  • Ignore whitespace
             prop.merge(self, instance, merged, dont_load, _recursive)
 
         if dont_load:
+            # needs savepoint ?
             attributes.instance_state(merged).commit_all()  # remove any history
 
         if new_instance:

File test/orm/attr_rollback.py

View file
  • Ignore whitespace
         attributes.instance_state(f).commit_all()
 
         assert attributes.get_history(attributes.instance_state(f), 'x') == ([], hist2, [])
+    
+    def test_expire(self):
         
-    
+        class Test(object):
+            pass
+
+        called = [0]
+
+        attributes.register_class(Test)
+        self.register_attribute(Test, 'x')
+        if Test.x.impl.accepts_scalar_loader:
+            uses_scalar_loader = True
+            def load(state, keys):
+                called[0] += 1
+                state.dict['x'] = data2
+                state.commit(['x'])
+        
+            manager = attributes.manager_of_class(Test).deferred_scalar_loader = load
+        else:
+            # expiry uses the "lazy loaders" for object based attributes
+            uses_scalar_loader = False
+            def foo(state):
+                def load():
+                    called[0] += 1
+                    return data2
+                return load
+            Test.x.impl.callable_ = foo
+        
+        f = Test()
+
+        f.x = data1
+        if uses_scalar_loader:
+            assert called == [0]
+        else:
+            assert called == [1]
+
+#        import pdb
+#        pdb.set_trace()
+
+        attributes.instance_state(f).commit_all(savepoint_id=1)
+        assert attributes.get_history(attributes.instance_state(f), 'x') == ([], hist1, [])
+        
+        attributes.instance_state(f).expire_attributes(['x'])
+        if uses_scalar_loader:
+            assert attributes.get_history(attributes.instance_state(f), 'x') == ([], emptyhist, [])
+            assert called == [0]
+        else:
+            assert attributes.get_history(attributes.instance_state(f), 'x') == ([], hist2, [])
+            assert called == [2]
+        
+        assert f.x == data2
+        assert attributes.get_history(attributes.instance_state(f), 'x') == ([], hist2, [])
+        
+        f.x = data3
+        
+        attributes.instance_state(f).rollback(1)
+        assert attributes.get_history(attributes.instance_state(f), 'x') == ([], hist2, [])
+        if uses_scalar_loader:
+            assert called == [1]
+        else:
+            assert called == [2]
         
 class CollectionTestBase(AttrTestBase):
 
+    def register_attribute(self, cls, name, **kwargs):
+        attributes.register_attribute(cls, name, uselist=True, useobject=True, trackparent=True, typecallable=make_collection, **kwargs)
+
     def test_collection_rback_savepoint_rback_to_empty(self):
         b1, b2, b3, b4, b5 = Bar(), Bar(), Bar(), Bar(), Bar()
         f = Foo()
         
         for x in [b1, p1, p2, p3]:
             attributes.instance_state(x).commit_all(savepoint_id=1)
-        
+
         p4 = Post(name='p4')
         p5 = Post(name='p5')
         p4.blog = b1
         p5.blog = b1
-
+#        import pdb
+#        pdb.set_trace()
         assert b1.posts ==  make_collection([Post(name='p1'), Post(name='p2'), Post(name='p3'), Post(name='p4'), Post(name='p5')])
+        assert called == [1]
         assert attributes.get_history(attributes.instance_state(b1), 'posts') == ([p4, p5], [p1, p2, p3], [])
 
         for x in [b1, p1, p2, p3]:
         assert attributes.get_history(attributes.instance_state(b1), 'posts') == ([], [p1, p2, p3], [])
         assert b1.posts == make_collection([p1, p2, p3])
         
-        # TODO: more tests needed
+        # not sure how if we want lazy load to re-fire or not here
+#        assert called == [1]
 
 class ScalarTest(AttrTestBase, TestBase):
     def setUpAll(self):
         class Foo(object):pass
         attributes.register_class(Foo)
         attributes.register_attribute(Foo, 'x', uselist=False, useobject=False)
-
+    
+    def register_attribute(self, cls, name, **kwargs):
+        attributes.register_attribute(cls, name, uselist=False, useobject=False, **kwargs)
+        
 class MutableScalarTest(AttrTestBase, TestBase):
     def setUpAll(self):
         global Foo, data1, data2, data3, hist1, hist2, hist3, empty, emptyhist
         attributes.register_class(Foo)
         attributes.register_attribute(Foo, 'x', uselist=False, useobject=False, mutable_scalars=True, copy_function=dict)
 
+    def register_attribute(self, cls, name, **kwargs):
+        attributes.register_attribute(cls, name, uselist=False, useobject=False, mutable_scalars=True, copy_function=dict, **kwargs)
+
     def test_mutable1(self):
         f = Foo()
 
         empty = None
         emptyhist = [None]
 
+    def register_attribute(self, cls, name, **kwargs):
+        attributes.register_attribute(cls, name, uselist=False, useobject=True, trackparent=True, **kwargs)
+
     def test_hasparent(self):
         f = Foo()
         b = Bar()