Commits

Mike Bayer committed 5e5ffaf

- [bug] Fixed event registration bug
which would primarily show up as
events not being registered with
sessionmaker() instances created
after the event was associated
with the Session class. [ticket:2424]

Comments (0)

Files changed (5)

 0.7.6
 =====
 - orm
+  - [bug] Fixed event registration bug
+    which would primarily show up as
+    events not being registered with 
+    sessionmaker() instances created
+    after the event was associated
+    with the Session class.  [ticket:2424]
+
   - [feature] Added "no_autoflush" context
     manager to Session, used with with:
     will temporarily disable autoflush.

lib/sqlalchemy/event.py

 
 def listen(target, identifier, fn, *args, **kw):
     """Register a listener function for the given target.
-    
+
     e.g.::
-    
+
         from sqlalchemy import event
         from sqlalchemy.schema import UniqueConstraint
-        
+
         def unique_constraint_name(const, table):
             const.name = "uq_%s_%s" % (
                 table.name,
 
 def listens_for(target, identifier, *args, **kw):
     """Decorate a function as a listener for the given target + identifier.
-    
+
     e.g.::
-    
+
         from sqlalchemy import event
         from sqlalchemy.schema import UniqueConstraint
-        
+
         @event.listens_for(UniqueConstraint, "after_parent_attach")
         def unique_constraint_name(const, table):
             const.name = "uq_%s_%s" % (
     def insert(self, obj, target, propagate):
         assert isinstance(target, type), \
                 "Class-level Event targets must be classes."
-
         stack = [target]
         while stack:
             cls = stack.pop(0)
             stack.extend(cls.__subclasses__())
-            self._clslevel[cls].insert(0, obj)
+            if cls is not target and cls not in self._clslevel:
+                self.update_subclass(cls)
+            else:
+                self._clslevel[cls].insert(0, obj)
 
     def append(self, obj, target, propagate):
         assert isinstance(target, type), \
         while stack:
             cls = stack.pop(0)
             stack.extend(cls.__subclasses__())
-            self._clslevel[cls].append(obj)
+            if cls is not target and cls not in self._clslevel:
+                self.update_subclass(cls)
+            else:
+                self._clslevel[cls].append(obj)
+
+    def update_subclass(self, target):
+        clslevel = self._clslevel[target]
+        for cls in target.__mro__[1:]:
+            if cls in self._clslevel:
+                clslevel.extend([
+                    fn for fn 
+                    in self._clslevel[cls] 
+                    if fn not in clslevel
+                ])
 
     def remove(self, obj, target):
         stack = [target]
     _exec_once = False
 
     def __init__(self, parent, target_cls):
+        if target_cls not in parent._clslevel:
+            parent.update_subclass(target_cls)
         self.parent_listeners = parent._clslevel[target_cls]
         self.name = parent.__name__
         self.listeners = []

lib/sqlalchemy/orm/session.py

             kwargs.update(new_kwargs)
 
 
-    return type("Session", (Sess, class_), {})
+    return type("SessionMaker", (Sess, class_), {})
 
 
 class SessionTransaction(object):

test/base/test_events.py

             [listen_two]
         )
 
+class TestClsLevelListen(fixtures.TestBase):
+    def setUp(self):
+        class TargetEventsOne(event.Events):
+            def event_one(self, x, y):
+                pass
+        class TargetOne(object):
+            dispatch = event.dispatcher(TargetEventsOne)
+        self.TargetOne = TargetOne
+
+    def tearDown(self):
+        event._remove_dispatcher(
+            self.TargetOne.__dict__['dispatch'].events)
+
+    def test_lis_subcalss_lis(self):
+        @event.listens_for(self.TargetOne, "event_one")
+        def handler1(x, y):
+            print 'handler1'
+
+        class SubTarget(self.TargetOne):
+            pass
+
+        @event.listens_for(self.TargetOne, "event_one")
+        def handler2(x, y):
+            pass
+
+        eq_(
+            len(SubTarget().dispatch.event_one),
+            2
+        )
+
+    def test_lis_multisub_lis(self):
+        @event.listens_for(self.TargetOne, "event_one")
+        def handler1(x, y):
+            print 'handler1'
+
+        class SubTarget(self.TargetOne):
+            pass
+
+        class SubSubTarget(SubTarget):
+            pass
+
+        @event.listens_for(self.TargetOne, "event_one")
+        def handler2(x, y):
+            pass
+
+        eq_(
+            len(SubTarget().dispatch.event_one),
+            2
+        )
+        eq_(
+            len(SubSubTarget().dispatch.event_one),
+            2
+        )
+
+    def test_two_sub_lis(self):
+        class SubTarget1(self.TargetOne):
+            pass
+        class SubTarget2(self.TargetOne):
+            pass
+
+        @event.listens_for(self.TargetOne, "event_one")
+        def handler1(x, y):
+            pass
+        @event.listens_for(SubTarget1, "event_one")
+        def handler2(x, y):
+            pass
+
+        s1 = SubTarget1()
+        assert handler1 in s1.dispatch.event_one
+        assert handler2 in s1.dispatch.event_one
+
+        s2 = SubTarget2()
+        assert handler1 in s2.dispatch.event_one
+        assert handler2 not in s2.dispatch.event_one
+
+
+class TestClsLevelListen(fixtures.TestBase):
+    def setUp(self):
+        class TargetEventsOne(event.Events):
+            def event_one(self, x, y):
+                pass
+        class TargetOne(object):
+            dispatch = event.dispatcher(TargetEventsOne)
+        self.TargetOne = TargetOne
+
+    def tearDown(self):
+        event._remove_dispatcher(self.TargetOne.__dict__['dispatch'].events)
+
+    def test_lis_subcalss_lis(self):
+        @event.listens_for(self.TargetOne, "event_one")
+        def handler1(x, y):
+            print 'handler1'
+
+        class SubTarget(self.TargetOne):
+            pass
+
+        @event.listens_for(self.TargetOne, "event_one")
+        def handler2(x, y):
+            pass
+
+        eq_(
+            len(SubTarget().dispatch.event_one),
+            2
+        )
 class TestAcceptTargets(fixtures.TestBase):
     """Test default target acceptance."""
 

test/orm/test_events.py

         sess.flush()
         eq_(canary,
             ['init', 'before_insert',
-             'after_insert', 'expire', 'translate_row', 'populate_instance',
-             'refresh',
+             'after_insert', 'expire', 'translate_row', 
+             'populate_instance', 'refresh',
              'append_result', 'translate_row', 'create_instance',
              'populate_instance', 'load', 'append_result',
-             'before_update', 'after_update', 'before_delete', 'after_delete'])
+             'before_update', 'after_update', 'before_delete', 
+             'after_delete'])
 
     def test_merge(self):
         users, User = self.tables.users, self.classes.User
 
         """
 
-        keywords, items, item_keywords, Keyword, Item = (self.tables.keywords,
+        keywords, items, item_keywords, Keyword, Item = (
+                                self.tables.keywords,
                                 self.tables.items,
                                 self.tables.item_keywords,
                                 self.classes.Keyword,
         assert my_listener in s.dispatch.before_flush
 
     def test_sessionmaker_listen(self):
-        """test that listen can be applied to individual scoped_session() classes."""
+        """test that listen can be applied to individual 
+        scoped_session() classes."""
 
         def my_listener_one(*arg, **kw):
             pass
 
         mapper(User, users)
 
-        sess, canary = self._listener_fixture(autoflush=False, autocommit=True, expire_on_commit=False)
+        sess, canary = self._listener_fixture(autoflush=False, 
+                            autocommit=True, expire_on_commit=False)
 
         u = User(name='u1')
         sess.add(u)
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.