Commits

Mike Bayer  committed 0848576

- AttributeListener has been refined such that the event
is fired before the mutation actually occurs. Addtionally,
the append() and set() methods must now return the given value,
which is used as the value to be used in the mutation operation.
This allows creation of validating AttributeListeners which
raise before the action actually occurs, and which can change
the given value into something else before its used.
A new example "validate_attributes.py" shows one such recipe
for doing this. AttributeListener helper functions are
also on the way.

  • Participants
  • Parent commits 7e330f2

Comments (0)

Files changed (9)

       clause will appear in the WHERE clause of the query as well
       since this discrimination has multiple trigger points.
 
+    - AttributeListener has been refined such that the event
+      is fired before the mutation actually occurs.  Addtionally,
+      the append() and set() methods must now return the given value,
+      which is used as the value to be used in the mutation operation.
+      This allows creation of validating AttributeListeners which
+      raise before the action actually occurs, and which can change
+      the given value into something else before its used.
+      A new example "validate_attributes.py" shows one such recipe
+      for doing this.   AttributeListener helper functions are
+      also on the way.
+      
     - class.someprop.in_() raises NotImplementedError pending the
       implementation of "in_" for relation [ticket:1140]
 

File examples/custom_attributes/listen_for_events.py

 
 class InstallListeners(InstrumentationManager):
     def instrument_attribute(self, class_, key, inst):
-        """Add an event listener to all InstrumentedAttributes."""
+        """Add an event listener to an InstrumentedAttribute."""
         
-        inst.impl.extensions.append(AttributeListener(key))
+        inst.impl.extensions.insert(0, AttributeListener(key))
         return super(InstallListeners, self).instrument_attribute(class_, key, inst)
         
 class AttributeListener(AttributeExtension):
     
     def append(self, state, value, initiator):
         self._report(state, value, None, "appended")
+        return value
 
     def remove(self, state, value, initiator):
         self._report(state, value, None, "removed")
 
     def set(self, state, value, oldvalue, initiator):
         self._report(state, value, oldvalue, "set")
+        return value
     
     def _report(self, state, value, oldvalue, verb):
         state.obj().receive_change_event(verb, self.key, value, oldvalue)

File examples/custom_attributes/validate_attributes.py

+"""
+Illustrates how to use AttributeExtension to create attribute validators.
+
+"""
+
+from sqlalchemy.orm.interfaces import AttributeExtension, InstrumentationManager
+
+class InstallValidators(InstrumentationManager):
+    """Searches a class for methods with a '_validates' attribute and assembles Validators."""
+    
+    def __init__(self, cls):
+        self.validators = {}
+        for k in dir(cls):
+            item = getattr(cls, k)
+            if hasattr(item, '_validates'):
+                self.validators[item._validates] = item
+                
+    def instrument_attribute(self, class_, key, inst):
+        """Add an event listener to an InstrumentedAttribute."""
+        
+        if key in self.validators:
+            inst.impl.extensions.insert(0, Validator(key, self.validators[key]))
+        return super(InstallValidators, self).instrument_attribute(class_, key, inst)
+        
+class Validator(AttributeExtension):
+    """Validates an attribute, given the key and a validation function."""
+    
+    def __init__(self, key, validator):
+        self.key = key
+        self.validator = validator
+    
+    def append(self, state, value, initiator):
+        return self.validator(state.obj(), value)
+
+    def set(self, state, value, oldvalue, initiator):
+        return self.validator(state.obj(), value)
+
+def validates(key):
+    """Mark a method as validating a named attribute."""
+    
+    def wrap(fn):
+        fn._validates = key
+        return fn
+    return wrap
+
+if __name__ == '__main__':
+
+    from sqlalchemy import *
+    from sqlalchemy.orm import *
+    from sqlalchemy.ext.declarative import declarative_base
+    import datetime
+    
+    Base = declarative_base(engine=create_engine('sqlite://', echo=True))
+    Base.__sa_instrumentation_manager__ = InstallValidators
+
+    class MyMappedClass(Base):
+        __tablename__ = "mytable"
+    
+        id = Column(Integer, primary_key=True)
+        date = Column(Date)
+        related_id = Column(Integer, ForeignKey("related.id"))
+        related = relation("Related", backref="mapped")
+
+        @validates('date')
+        def check_date(self, value):
+            if isinstance(value, str):
+                m, d, y = [int(x) for x in value.split('/')]
+                return datetime.date(y, m, d)
+            else:
+                assert isinstance(value, datetime.date)
+                return value
+        
+        @validates('related')
+        def check_related(self, value):
+            assert value.data == 'r1'
+            return value
+            
+        def __str__(self):
+            return "MyMappedClass(date=%r)" % self.date
+            
+    class Related(Base):
+        __tablename__ = "related"
+
+        id = Column(Integer, primary_key=True)
+        data = Column(String(50))
+
+        def __str__(self):
+            return "Related(data=%r)" % self.data
+    
+    Base.metadata.create_all()
+    session = sessionmaker()()
+    
+    r1 = Related(data='r1')
+    r2 = Related(data='r2')
+    m1 = MyMappedClass(date='5/2/2005', related=r1)
+    m2 = MyMappedClass(date=datetime.date(2008, 10, 15))
+    r1.mapped.append(m2)
+
+    try:
+        m1.date = "this is not a date"
+    except:
+        pass
+    assert m1.date == datetime.date(2005, 5, 2)
+    
+    try:
+        m2.related = r2
+    except:
+        pass
+    assert m2.related is r1
+    
+    session.add(m1)
+    session.commit()
+    assert session.query(MyMappedClass.date).order_by(MyMappedClass.date).all() == [
+        (datetime.date(2005, 5, 2),),
+        (datetime.date(2008, 10, 15),)
+    ]
+    

File lib/sqlalchemy/orm/attributes.py

         state.modified_event(self, False, old)
 
         if self.extensions:
+            self.fire_remove_event(state, old, None)
             del state.dict[self.key]
-            self.fire_remove_event(state, old, None)
         else:
             del state.dict[self.key]
 
         state.modified_event(self, False, old)
 
         if self.extensions:
+            value = self.fire_replace_event(state, value, old, initiator)
             state.dict[self.key] = value
-            self.fire_replace_event(state, value, old, initiator)
         else:
             state.dict[self.key] = value
 
     def fire_replace_event(self, state, value, previous, initiator):
         for ext in self.extensions:
-            ext.set(state, value, previous, initiator or self)
+            value = ext.set(state, value, previous, initiator or self)
+        return value
 
     def fire_remove_event(self, state, value, initiator):
         for ext in self.extensions:
 
         if self.extensions:
             old = self.get(state)
+            value = self.fire_replace_event(state, value, old, initiator)
             state.dict[self.key] = value
-            self.fire_replace_event(state, value, old, initiator)
         else:
             state.dict[self.key] = value
 
 
     def delete(self, state):
         old = self.get(state)
-        # TODO: catch key errors, convert to attributeerror?
+        self.fire_remove_event(state, old, self)
         del state.dict[self.key]
-        self.fire_remove_event(state, old, self)
 
     def get_history(self, state, passive=False):
         if self.key in state.dict:
 
         # may want to add options to allow the get() here to be passive
         old = self.get(state)
+        value = self.fire_replace_event(state, value, old, initiator)
         state.dict[self.key] = value
-        self.fire_replace_event(state, value, old, initiator)
 
     def fire_remove_event(self, state, value, initiator):
         state.modified_event(self, False, value)
                 self.sethasparent(instance_state(previous), False)
 
         for ext in self.extensions:
-            ext.set(state, value, previous, initiator or self)
+            value = ext.set(state, value, previous, initiator or self)
+        return value
 
 
 class CollectionAttributeImpl(AttributeImpl):
             self.sethasparent(instance_state(value), True)
 
         for ext in self.extensions:
-            ext.append(state, value, initiator or self)
+            value = ext.append(state, value, initiator or self)
+        return value
 
     def fire_pre_remove_event(self, state, initiator):
         state.modified_event(self, True, NEVER_SET, passive=True)
 
         collection = self.get_collection(state, passive=passive)
         if collection is PASSIVE_NORESULT:
+            value = self.fire_append_event(state, value, initiator)
             state.get_pending(self.key).append(value)
-            self.fire_append_event(state, value, initiator)
         else:
             collection.append_with_event(value, initiator)
 
 
         collection = self.get_collection(state, passive=passive)
         if collection is PASSIVE_NORESULT:
+            self.fire_remove_event(state, value, initiator)
             state.get_pending(self.key).remove(value)
-            self.fire_remove_event(state, value, initiator)
         else:
             collection.remove_with_event(value, initiator)
 
 
     def set(self, state, child, oldchild, initiator):
         if oldchild is child:
-            return
+            return child
         if oldchild is not None:
             # With lazy=None, there's no guarantee that the full collection is
             # present when updating via a backref.
         if child is not None:
             new_state = instance_state(child)
             new_state.get_impl(self.key).append(new_state, state.obj(), initiator, passive=True)
-
+        return child
+        
     def append(self, state, child, initiator):
         child_state = instance_state(child)
         child_state.get_impl(self.key).append(child_state, state.obj(), initiator, passive=True)
-
+        return child
+        
     def remove(self, state, child, initiator):
         if child is not None:
             child_state = instance_state(child)

File lib/sqlalchemy/orm/collections.py

 
         """
         if initiator is not False and item is not None:
-            self.attr.fire_append_event(self.owner_state, item, initiator)
+            return self.attr.fire_append_event(self.owner_state, item, initiator)
+        else:
+            return item
 
     def fire_remove_event(self, item, initiator=None):
         """Notify that a entity has been removed from the collection.
 
 def __set(collection, item, _sa_initiator=None):
     """Run set events, may eventually be inlined into decorators."""
+
     if _sa_initiator is not False and item is not None:
         executor = getattr(collection, '_sa_adapter', None)
         if executor:
-            getattr(executor, 'fire_append_event')(item, _sa_initiator)
-
+            item = getattr(executor, 'fire_append_event')(item, _sa_initiator)
+    return item
+    
 def __del(collection, item, _sa_initiator=None):
     """Run del events, may eventually be inlined into decorators."""
     if _sa_initiator is not False and item is not None:
 
     def append(fn):
         def append(self, item, _sa_initiator=None):
-            __set(self, item, _sa_initiator)
+            item = __set(self, item, _sa_initiator)
             fn(self, item)
         _tidy(append)
         return append
 
     def insert(fn):
         def insert(self, index, value):
-            __set(self, value)
+            value = __set(self, value)
             fn(self, index, value)
         _tidy(insert)
         return insert
                 existing = self[index]
                 if existing is not None:
                     __del(self, existing)
-                __set(self, value)
+                value = __set(self, value)
                 fn(self, index, value)
             else:
                 # slice assignment requires __delitem__, insert, __len__
         def __setslice__(self, start, end, values):
             for value in self[start:end]:
                 __del(self, value)
-            for value in values:
-                __set(self, value)
+            values = [__set(self, value) for value in values]
             fn(self, start, end, values)
         _tidy(__setslice__)
         return __setslice__
         def __setitem__(self, key, value, _sa_initiator=None):
             if key in self:
                 __del(self, self[key], _sa_initiator)
-            __set(self, value, _sa_initiator)
+            value = __set(self, value, _sa_initiator)
             fn(self, key, value)
         _tidy(__setitem__)
         return __setitem__
     def add(fn):
         def add(self, value, _sa_initiator=None):
             if value not in self:
-                __set(self, value, _sa_initiator)
+                value = __set(self, value, _sa_initiator)
             # testlib.pragma exempt:__hash__
             fn(self, value)
         _tidy(add)

File lib/sqlalchemy/orm/interfaces.py

     """An event handler for individual attribute change events.
     
     AttributeExtension is assembled within the descriptors associated 
-    with a mapped class.
+    with a mapped class. 
     
     """
 
     def append(self, state, value, initiator):
-        pass
+        """Receive a collection append event.
+        
+        The returned value will be used as the actual value to be
+        appended.
+        
+        """
+        return value
 
     def remove(self, state, value, initiator):
+        """Receive a remove event.
+        
+        No return value is defined.
+        
+        """
         pass
 
     def set(self, state, value, oldvalue, initiator):
-        pass
+        """Receive a set event.
+        
+        The returned value will be used as the actual value to be
+        set.
+        
+        """
+        return value
 
 
 class StrategizedOption(PropertyOption):

File lib/sqlalchemy/orm/unitofwork.py

             prop = _state_mapper(state).get_property(self.key)
             if prop.cascade.save_update and item not in sess:
                 sess.save_or_update(item)
-
+        return item
+        
     def remove(self, state, item, initiator):
         sess = _state_session(state)
         if sess:
     def set(self, state, newvalue, oldvalue, initiator):
         # process "save_update" cascade rules for when an instance is attached to another instance
         if oldvalue is newvalue:
-            return
+            return newvalue
         sess = _state_session(state)
         if sess:
             prop = _state_mapper(state).get_property(self.key)
                 sess.save_or_update(newvalue)
             if prop.cascade.delete_orphan and oldvalue in sess.new:
                 sess.expunge(oldvalue)
-
+        return newvalue
 
 def register_attribute(class_, key, *args, **kwargs):
     """overrides attributes.register_attribute() to add UOW event handlers

File test/orm/attributes.py

 
             def set(self, state, child, oldchild, initiator):
                 results.append(("set", state.obj(), child, oldchild))
+                return child
         
         attributes.register_class(Foo)
         attributes.register_attribute(Foo, 'x', uselist=False, mutable_scalars=False, useobject=False, extension=ReceiveEvents())
         assert f.bar is None
         eq_(attributes.get_history(attributes.instance_state(f), 'bar'), ([None], (), [bar1]))
 
+class ListenerTest(_base.ORMTest):
+    def test_receive_changes(self):
+        """test that Listeners can mutate the given value.
+        
+        This is a rudimentary test which would be better suited by a full-blown inclusion
+        into collection.py.
+        
+        """
+        class Foo(object):
+            pass
+        class Bar(object):
+            pass
+
+        class AlteringListener(AttributeExtension):
+            def append(self, state, child, initiator):
+                b2 = Bar()
+                b2.data = b1.data + " appended"
+                return b2
+
+            def set(self, state, value, oldvalue, initiator):
+                return value + " modified"
+
+        attributes.register_class(Foo)
+        attributes.register_class(Bar)
+        attributes.register_attribute(Foo, 'data', uselist=False, useobject=False, extension=AlteringListener())
+        attributes.register_attribute(Foo, 'barlist', uselist=True, useobject=True, extension=AlteringListener())
+        attributes.register_attribute(Foo, 'barset', typecallable=set, uselist=True, useobject=True, extension=AlteringListener())
+        attributes.register_attribute(Bar, 'data', uselist=False, useobject=False)
+        
+        f1 = Foo()
+        f1.data = "some data"
+        eq_(f1.data, "some data modified")
+        b1 = Bar()
+        b1.data = "some bar"
+        f1.barlist.append(b1)
+        assert b1.data == "some bar"
+        assert f1.barlist[0].data == "some bar appended"
+        
+        f1.barset.add(b1)
+        assert f1.barset.pop().data == "some bar appended"
+    
     
 if __name__ == "__main__":
     testenv.main()

File test/orm/collection.py

         assert value not in self.added
         self.data.add(value)
         self.added.add(value)
+        return value
     def remove(self, obj, value, initiator):
         assert value not in self.removed
         self.data.remove(value)
         self.removed.add(value)
     def set(self, obj, value, oldvalue, initiator):
+        if isinstance(value, str):
+            value = CollectionsTest.entity_maker()
+
         if oldvalue is not None:
             self.remove(obj, oldvalue, None)
         self.append(obj, value, None)
-
+        return value
 
 class CollectionsTest(_base.ORMTest):
     class Entity(object):