Commits

Mike Bayer committed df96868 Draft

- simplify setup_entity and related calls
- break _compile_context() into three methods

  • Participants
  • Parent commits 854d287

Comments (0)

Files changed (1)

File lib/sqlalchemy/orm/query.py

             for entity in ent.entities:
                 if entity not in d:
                     ext_info = _extended_entity_info(entity)
-                    if not ext_info.is_aliased_class and ext_info.mapper.with_polymorphic:
+                    if not ext_info.is_aliased_class and \
+                        ext_info.mapper.with_polymorphic:
                         if ext_info.mapper.mapped_table not in \
                                             self._polymorphic_adapters:
                             self._mapper_loads_polymorphically_with(ext_info.mapper, 
                     else:
                         aliased_adapter = None
 
-                    d[entity] = (ext_info.mapper, aliased_adapter, ext_info.selectable, 
-                                        ext_info.is_aliased_class, ext_info.with_polymorphic_mappers,
-                                        ext_info.with_polymorphic_discriminator)
-                ent.setup_entity(entity, *d[entity])
+                    d[entity] = (
+                        ext_info,
+                        aliased_adapter
+                    )
+                ent.setup_entity(*d[entity])
 
     def _mapper_loads_polymorphically_with(self, mapper, adapter):
         for m2 in mapper._with_polymorphic_mappers:
         to count, to skip the usage of a subquery or
         otherwise control of the FROM clause,
         or to use other aggregate functions,
-        use :attr:`~sqlalchemy.sql.expression.func` expressions in conjunction
+        use :attr:`~sqlalchemy.sql.expression.func` 
+        expressions in conjunction
         with :meth:`~.Session.query`, i.e.::
 
             from sqlalchemy import func
         """
         #TODO: cascades need handling.
 
-        delete_op = persistence.BulkDelete.factory(self, synchronize_session)
+        delete_op = persistence.BulkDelete.factory(
+                            self, synchronize_session)
         delete_op.exec_()
         return delete_op.rowcount
 
         # fk assignments
         #TODO: cascades need handling.
 
-        update_op = persistence.BulkUpdate.factory(self, synchronize_session, values)
+        update_op = persistence.BulkUpdate.factory(
+                            self, synchronize_session, values)
         update_op.exec_()
         return update_op.rowcount
 
 
+    _lockmode_lookup = {
+            'read': 'read',
+              'read_nowait': 'read_nowait',
+              'update': True,
+              'update_nowait': 'nowait',
+              None: False
+    }
+
     def _compile_context(self, labels=True):
         context = QueryContext(self)
 
         if context.statement is not None:
             return context
 
+        context.labels = labels
+
         if self._lockmode:
             try:
-                for_update = {'read': 'read',
-                              'read_nowait': 'read_nowait',
-                              'update': True,
-                              'update_nowait': 'nowait',
-                              None: False}[self._lockmode]
+                context.for_update = self._lockmode_lookup[self._lockmode]
             except KeyError:
                 raise sa_exc.ArgumentError(
-                            "Unknown lockmode %r" % self._lockmode)
-        else:
-            for_update = False
-
+                                "Unknown lockmode %r" % self._lockmode)
         for entity in self._entities:
             entity.setup_context(self, context)
 
         if context.from_clause:
             # "load from explicit FROMs" mode, 
             # i.e. when select_from() or join() is used
-            froms = list(context.from_clause)
+            context.froms = list(context.from_clause)
         else:
             # "load from discrete FROMs" mode, 
             # i.e. when each _MappedEntity has its own FROM
-            froms = context.froms
+            context.froms = context.froms
 
         if self._enable_single_crit:
             self._adjust_for_single_inheritance(context)
                             "SELECT from.")
 
         if context.multi_row_eager_loaders and self._should_nest_selectable:
-            # for eager joins present and LIMIT/OFFSET/DISTINCT, 
-            # wrap the query inside a select,
-            # then append eager joins onto that
-
-            if context.order_by:
-                order_by_col_expr = list(
-                                        chain(*[
-                                            sql_util.unwrap_order_by(o)
-                                            for o in context.order_by
-                                        ])
-                                    )
-            else:
-                context.order_by = None
-                order_by_col_expr = []
-
-            inner = sql.select(
-                        context.primary_columns + order_by_col_expr,
+            context.statement = self._compound_eager_statement(context)
+        else:
+            context.statement = self._simple_statement(context)
+        return context
+
+    def _compound_eager_statement(self, context):
+        # for eager joins present and LIMIT/OFFSET/DISTINCT, 
+        # wrap the query inside a select,
+        # then append eager joins onto that
+
+        if context.order_by:
+            order_by_col_expr = list(
+                                    chain(*[
+                                        sql_util.unwrap_order_by(o)
+                                        for o in context.order_by
+                                    ])
+                                )
+        else:
+            context.order_by = None
+            order_by_col_expr = []
+
+        inner = sql.select(
+                    context.primary_columns + order_by_col_expr,
+                    context.whereclause,
+                    from_obj=context.froms,
+                    use_labels=context.labels,
+                    # TODO: this order_by is only needed if 
+                    # LIMIT/OFFSET is present in self._select_args,
+                    # else the application on the outside is enough
+                    order_by=context.order_by,
+                    **self._select_args
+                )
+
+        for hint in self._with_hints:
+            inner = inner.with_hint(*hint)
+
+        if self._correlate:
+            inner = inner.correlate(*self._correlate)
+
+        inner = inner.alias()
+
+        equivs = self.__all_equivs()
+
+        context.adapter = sql_util.ColumnAdapter(inner, equivs)
+
+        statement = sql.select(
+                            [inner] + context.secondary_columns, 
+                            for_update=context.for_update, 
+                            use_labels=context.labels)
+
+        from_clause = inner
+        for eager_join in context.eager_joins.values():
+            # EagerLoader places a 'stop_on' attribute on the join,
+            # giving us a marker as to where the "splice point" of 
+            # the join should be
+            from_clause = sql_util.splice_joins(
+                                        from_clause, 
+                                        eager_join, eager_join.stop_on)
+
+        statement.append_from(from_clause)
+
+        if context.order_by:
+            statement.append_order_by(
+                *context.adapter.copy_and_process(
+                    context.order_by
+                )
+            )
+
+        statement.append_order_by(*context.eager_order_by)
+        return statement
+
+    def _simple_statement(self, context):
+        if not context.order_by:
+            context.order_by = None
+
+        if self._distinct and context.order_by:
+            order_by_col_expr = list(
+                                    chain(*[
+                                        sql_util.unwrap_order_by(o) 
+                                        for o in context.order_by
+                                    ])
+                                )
+            context.primary_columns += order_by_col_expr
+
+        context.froms += tuple(context.eager_joins.values())
+
+        statement = sql.select(
+                        context.primary_columns +
+                                context.secondary_columns,
                         context.whereclause,
-                        from_obj=froms,
-                        use_labels=labels,
-                        # TODO: this order_by is only needed if 
-                        # LIMIT/OFFSET is present in self._select_args,
-                        # else the application on the outside is enough
+                        from_obj=context.froms,
+                        use_labels=context.labels,
+                        for_update=context.for_update,
                         order_by=context.order_by,
                         **self._select_args
                     )
 
-            for hint in self._with_hints:
-                inner = inner.with_hint(*hint)
-
-            if self._correlate:
-                inner = inner.correlate(*self._correlate)
-
-            inner = inner.alias()
-
-            equivs = self.__all_equivs()
-
-            context.adapter = sql_util.ColumnAdapter(inner, equivs)
-
-            statement = sql.select(
-                                [inner] + context.secondary_columns, 
-                                for_update=for_update, 
-                                use_labels=labels)
-
-            from_clause = inner
-            for eager_join in eager_joins:
-                # EagerLoader places a 'stop_on' attribute on the join,
-                # giving us a marker as to where the "splice point" of 
-                # the join should be
-                from_clause = sql_util.splice_joins(
-                                            from_clause, 
-                                            eager_join, eager_join.stop_on)
-
-            statement.append_from(from_clause)
-
-            if context.order_by:
-                statement.append_order_by(
-                    *context.adapter.copy_and_process(
-                        context.order_by
-                    )
-                )
-
+        for hint in self._with_hints:
+            statement = statement.with_hint(*hint)
+
+        if self._correlate:
+            statement = statement.correlate(*self._correlate)
+
+        if context.eager_order_by:
             statement.append_order_by(*context.eager_order_by)
-        else:
-            if not context.order_by:
-                context.order_by = None
-
-            if self._distinct and context.order_by:
-                order_by_col_expr = list(
-                                        chain(*[
-                                            sql_util.unwrap_order_by(o) 
-                                            for o in context.order_by
-                                        ])
-                                    )
-                context.primary_columns += order_by_col_expr
-
-            froms += tuple(context.eager_joins.values())
-
-            statement = sql.select(
-                            context.primary_columns +
-                                    context.secondary_columns,
-                            context.whereclause,
-                            from_obj=froms,
-                            use_labels=labels,
-                            for_update=for_update,
-                            order_by=context.order_by,
-                            **self._select_args
-                        )
-
-            for hint in self._with_hints:
-                statement = statement.with_hint(*hint)
-
-            if self._correlate:
-                statement = statement.correlate(*self._correlate)
-
-            if context.eager_order_by:
-                statement.append_order_by(*context.eager_order_by)
-
-        context.statement = statement
-
-        return context
+        return statement
+
 
     def _adjust_for_single_inheritance(self, context):
         """Apply single-table-inheritance filtering.
-
-        For all distinct single-table-inheritance mappers represented in the
-        columns clause of this query, add criterion to the WHERE clause of the
-        given QueryContext such that only the appropriate subtypes are
-        selected from the total results.
-
+        
+        For all distinct single-table-inheritance mappers represented in
+        the columns clause of this query, add criterion to the WHERE
+        clause of the given QueryContext such that only the appropriate
+        subtypes are selected from the total results.
+        
         """
-        for entity, (mapper, adapter, s, i, w, d) in \
-                            self._mapper_adapter_map.iteritems():
-            if entity in self._join_entities:
+        for (ext_info, adapter) in self._mapper_adapter_map.values():
+            if ext_info.entity in self._join_entities:
                 continue
-            single_crit = mapper._single_table_criterion
+            single_crit = ext_info.mapper._single_table_criterion
             if single_crit is not None:
                 if adapter:
                     single_crit = adapter.traverse(single_crit)
                 single_crit = self._adapt_clause(single_crit, False, False)
-                context.whereclause = sql.and_(
-                                            context.whereclause, single_crit)
+                context.whereclause = sql.and_(context.whereclause, 
+                                            single_crit)
 
     def __str__(self):
         return str(self._compile_context().statement)
         self.entities = [entity]
         self.expr = entity
 
-    def setup_entity(self, entity, mapper, aliased_adapter, 
-                        from_obj, is_aliased_class, 
-                        with_polymorphic,
-                        with_polymorphic_discriminator):
-        self.mapper = mapper
+    def setup_entity(self, ext_info, aliased_adapter):
+        self.mapper = ext_info.mapper
         self.aliased_adapter = aliased_adapter
-        self.selectable  = from_obj
-        self.is_aliased_class = is_aliased_class
-        self._with_polymorphic = with_polymorphic
-        self._polymorphic_discriminator = with_polymorphic_discriminator
-        if is_aliased_class:
-            self.entity_zero = entity
+        self.selectable  = ext_info.selectable
+        self.is_aliased_class = ext_info.is_aliased_class
+        self._with_polymorphic = ext_info.with_polymorphic_mappers
+        self._polymorphic_discriminator = \
+                ext_info.with_polymorphic_discriminator
+        if ext_info.is_aliased_class:
+            self.entity_zero = ext_info.entity
             self._label_name = self.entity_zero._sa_label_name
         else:
-            self.entity_zero = mapper
+            self.entity_zero = self.mapper
             self._label_name = self.mapper.class_.__name__
         self.path = self.entity_zero._sa_path_registry
 
         c.entity_zero = self.entity_zero
         c.entities = self.entities
 
-    def setup_entity(self, entity, mapper, adapter, from_obj,
-                                is_aliased_class, with_polymorphic,
-                                with_polymorphic_discriminator):
+    def setup_entity(self, ext_info, aliased_adapter):
         if 'selectable' not in self.__dict__: 
-            self.selectable = from_obj
-        self.froms.add(from_obj)
+            self.selectable = ext_info.selectable
+        self.froms.add(ext_info.selectable)
 
     def corresponds_to(self, entity):
         if self.entity_zero is None:
     multi_row_eager_loaders = False
     adapter = None
     froms = ()
+    for_update = False
 
     def __init__(self, query):