Commits

jason kirtland committed c4e28bb

- unrolled loops for the simplified Session.get_bind() args
- restored the chunk of test r4806 deleted (!)

Comments (0)

Files changed (2)

lib/sqlalchemy/orm/session.py

         clause
           Optional, any ``ClauseElement``
 
-        instance
-          Optional, an instance of a mapped class
-
         """
         return self.__connection(self.get_bind(mapper, clause, _state))
 
 
         _state
           Optional, SA internal representation of a mapped instance
-            
+
         """
         if mapper is clause is _state is None:
             if self.bind:
                     "Connection, and no context was provided to locate "
                     "a binding.")
 
-        mappers = []
-        if _state is not None:
-            mappers.append(_state_mapper(_state))
-        if mapper is not None:
-            mappers.append(_class_to_mapper(mapper))
+        s_mapper = _state is not None and _state_mapper(_state) or None
+        c_mapper = mapper is not None and _class_to_mapper(mapper) or None
 
         # manually bound?
         if self.__binds:
-            for m in mappers:
-                if m.base_mapper in self.__binds:
-                    return self.__binds[m.base_mapper]
-                elif m.mapped_table in self.__binds:
-                    return self.__binds[m.mapped_table]
+            if s_mapper:
+                if s_mapper.base_mapper in self.__binds:
+                    return self.__binds[s_mapper.base_mapper]
+                elif s_mapper.mapped_table in self.__binds:
+                    return self.__binds[s_mapper.mapped_table]
+            if c_mapper:
+                if c_mapper.base_mapper in self.__binds:
+                    return self.__binds[c_mapper.base_mapper]
+                elif c_mapper.mapped_table in self.__binds:
+                    return self.__binds[c_mapper.mapped_table]
             if clause:
                 for t in sql_util.find_tables(clause):
                     if t in self.__binds:
         if isinstance(clause, sql.expression.ClauseElement) and clause.bind:
             return clause.bind
 
-        for m in mappers:
-            if m.mapped_table.bind:
-                return m.mapped_table.bind
+        if s_mapper and s_mapper.mapped_table.bind:
+            return s_mapper.mapped_table.bind
+        if c_mapper and c_mapper.mapped_table.bind:
+            return c_mapper.mapped_table.bind
 
         context = []
         if mapper is not None:
-            context.append('mapper %s' % _class_to_mapper(mapper))
+            context.append('mapper %s' % c_mapper)
         if clause is not None:
             context.append('SQL expression')
         if _state is not None:

test/orm/session.py

 
     # TODO: expand with message body assertions.
 
-    _class_methods = set(('get', 'load'))
+    _class_methods = set((
+        'connection', 'execute', 'get', 'get_bind', 'load', 'scalar'))
 
     def _public_session_methods(self):
         Session = sa.orm.session.Session
 
-        blacklist = set(('begin', 'query', 'connection', 'execute', 'get_bind', 'scalar'))
+        blacklist = set(('begin', 'query'))
 
         ok = set()
         for meth in Session.public_methods:
             self.assertRaises(sa.orm.exc.UnmappedClassError,
                               callable_, *args, **kw)
 
+        raises_('connection', mapper=user_arg)
+
+        raises_('execute', 'SELECT 1', mapper=user_arg)
+
         raises_('get', user_arg, 1)
 
+        raises_('get_bind', mapper=user_arg)
+
         raises_('load', user_arg, 1)
 
+        raises_('scalar', 'SELECT 1', mapper=user_arg)
+
         eq_(watchdog, self._class_methods,
             watchdog.symmetric_difference(self._class_methods))