Commits

Mike Bayer committed 72f67ec

- removed _calculate_correlations() methods, removed correlation_stack, select_stack;
all are merged into a single stack thats all within ansicompiler. clause visiting cut down
significantly.

  • Participants
  • Parent commits 77fd6f8

Comments (0)

Files changed (2)

File lib/sqlalchemy/ansisql.py

         # actually present in the generated SQL
         self.bind_names = {}
 
-        # when the compiler visits a SELECT statement, the clause object is appended
-        # to this stack.  various visit operations will check this stack to determine
-        # additional choices (TODO: it seems to be all typemap stuff.  shouldnt this only
-        # apply to the topmost-level SELECT statement ?)
-        self.select_stack = []
-
+        # a stack.  what recursive compiler doesn't have a stack ? :)
+        self.stack = []
+        
         # a dictionary of result-set column names (strings) to TypeEngine instances,
         # which will be passed to a ResultProxy and used for resultset-level value conversion
         self.typemap = {}
         # an ANSIIdentifierPreparer that formats the quoting of identifiers
         self.preparer = dialect.identifier_preparer
         
-        # a dictionary containing attributes about all select()
-        # elements located within the clause, regarding which are subqueries, which are
-        # selected from, and which elements should be correlated to an enclosing select.
-        # used mostly to determine the list of FROM elements for each select statement, as well
-        # as some dialect-specific rules regarding subqueries.
-        self.correlate_state = {}
-        
         # for UPDATE and INSERT statements, a set of columns whos values are being set
         # from a SQL expression (i.e., not one of the bind parameter values).  if present,
         # default-value logic in the Dialect knows not to fire off column defaults
         self.string = self.process(self.statement)
         self.after_compile()
     
-    def process(self, obj, **kwargs):
-        return self.traverse_single(obj, **kwargs)
+    def process(self, obj, stack=None, **kwargs):
+        if stack:
+            self.stack.append(stack)
+        try:
+            return self.traverse_single(obj, **kwargs)
+        finally:
+            if stack:
+                self.stack.pop(-1)
         
     def is_subquery(self, select):
-        return self.correlate_state[select].get('is_subquery', False)
+        return self.stack and self.stack[-1].get('is_subquery')
         
     def get_whereclause(self, obj):
         """given a FROM clause, return an additional WHERE condition that should be 
     def visit_label(self, label):
         labelname = self._truncated_identifier("colident", label.name)
         
-        if self.select_stack:
+        if self.stack and self.stack[-1].get('select'):
             self.typemap.setdefault(labelname.lower(), label.obj.type)
             if isinstance(label.obj, sql._ColumnClause):
                 self.column_labels[label.obj._label] = labelname
         else:
             name = column.name
 
-        if self.select_stack:
+        if self.stack and self.stack[-1].get('select'):
             # if we are within a visit to a Select, set up the "typemap"
             # for this column which is used to translate result set values
             self.typemap.setdefault(name.lower(), column.type)
         return self.process(clause.clause_expr)
 
     def visit_cast(self, cast, **kwargs):
-        if self.select_stack:
+        if self.stack and self.stack[-1].get('select'):
             # not sure if we want to set the typemap here...
             self.typemap.setdefault("CAST", cast.type)
         return "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause))
 
     def visit_function(self, func, **kwargs):
-        if self.select_stack:
+        if self.stack and self.stack[-1].get('select'):
             self.typemap.setdefault(func.name, func.type)
         if not self.apply_function_parens(func):
             return ".".join(func.packagenames + [func.name])
         else:
             return ".".join(func.packagenames + [func.name]) + (not func.group and " " or "") + self.process(func.clause_expr)
 
-    def visit_compound_select(self, cs, asfrom=False, **kwargs):
-        text = string.join([self.process(c) for c in cs.selects], " " + cs.keyword + " ")
-        group_by = self.process(cs._group_by_clause)
+    def visit_compound_select(self, cs, asfrom=False, parens=True, **kwargs):
+        text = string.join([self.process(c, asfrom=asfrom, parens=False) for c in cs.selects], " " + cs.keyword + " ")
+        group_by = self.process(cs._group_by_clause, asfrom=asfrom)
         if group_by:
             text += " GROUP BY " + group_by
         text += self.order_by_clause(cs)            
         text += (cs._limit or cs._offset) and self.limit_clause(cs) or ""
         
-        if asfrom:
+        if asfrom and parens:
             return "(" + text + ")"
         else:
             return text
         # names look like table.colname. so if column is in a "selected from"
         # subquery, label it synoymously with its column name
         if \
-            self.correlate_state[select].get('is_selected_from', False) and \
+            (self.stack and self.stack[-1].get('is_selected_from')) and \
             isinstance(column, sql._ColumnClause) and \
             not column.is_literal and \
             column.table is not None and \
             return column.label(column.name)
         else:
             return None
-            
-    def visit_select(self, select, asfrom=False, **kwargs):
 
-        select._calculate_correlations(self.correlate_state)
-        self.select_stack.append(select)
+    def visit_select(self, select, asfrom=False, parens=True, **kwargs):
+
+        stack_entry = {'select':select}
+        
+        if asfrom:
+            stack_entry['is_selected_from'] = stack_entry['is_subquery'] = True
+        elif self.stack and self.stack[-1].get('select'):
+            stack_entry['is_subquery'] = True
+
+        if self.stack and self.stack[-1].get('from'):
+            existingfroms = self.stack[-1]['from']
+        else:
+            existingfroms = None
+        froms = select._get_display_froms(existingfroms)
+
+        correlate_froms = util.Set()
+        for f in froms:
+            correlate_froms.add(f)
+            for f2 in f._get_from_objects():
+                correlate_froms.add(f2)
+
+        # TODO: might want to propigate existing froms for select(select(select))
+        # where innermost select should correlate to outermost
+#        if existingfroms:
+#            correlate_froms = correlate_froms.union(existingfroms)    
+        stack_entry['from'] = correlate_froms
+        self.stack.append(stack_entry)
 
         # the actual list of columns to print in the SELECT column list.
         inner_columns = util.OrderedSet()
-        
-        froms = select._get_display_froms(self.correlate_state)
                 
         for co in select.inner_columns:
             if select.use_labels:
                     inner_columns.add(self.process(l))
                 else:
                     inner_columns.add(self.process(co))
-                    
-        self.select_stack.pop(-1)
-
+            
         collist = string.join(inner_columns.difference(util.Set([None])), ', ')
 
         text = " ".join(["SELECT"] + [self.process(x) for x in select._prefixes]) + " "
         text += (select._limit or select._offset) and self.limit_clause(select) or ""
         text += self.for_update_clause(select)
 
-        if asfrom:
+        self.stack.pop(-1)
+
+        if asfrom and parens:
             return "(" + text + ")"
         else:
             return text
          " VALUES (" + string.join([c[1] for c in colparams], ', ') + ")")
 
     def visit_update(self, update_stmt):
-        update_stmt._calculate_correlations(self.correlate_state)
+        self.stack.append({'from':util.Set([update_stmt.table])})
         
         # search for columns who will be required to have an explicit bound value.
         # for updates, this includes Python-side "onupdate" defaults.
 
         if update_stmt._whereclause:
             text += " WHERE " + self.process(update_stmt._whereclause)
-
+        
+        self.stack.pop(-1)
+        
         return text
 
     def _get_colparams(self, stmt, required_cols):
         return values
 
     def visit_delete(self, delete_stmt):
-        delete_stmt._calculate_correlations(self.correlate_state)
+        self.stack.append({'from':util.Set([delete_stmt.table])})
 
         text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table)
 
         if delete_stmt._whereclause:
             text += " WHERE " + self.process(delete_stmt._whereclause)
 
+        self.stack.pop(-1)
+        
         return text
         
     def visit_savepoint(self, savepoint_stmt):

File lib/sqlalchemy/sql.py

 
         _SelectBaseMixin.__init__(self, **kwargs)
 
-    def _get_display_froms(self, correlation_state=None):
+    def _get_display_froms(self, existing_froms=None):
         """return the full list of 'from' clauses to be displayed.
         
-        takes into account an optional 'correlation_state' 
-        dictionary which contains information about this Select's
-        correlation to an enclosing select, which may cause some 'from'
-        clauses to not display in this Select's FROM clause.  
-        this dictionary is generated during compile time by the 
-        _calculate_correlations() method.  
-        
+        takes into account a set of existing froms which
+        may be rendered in the FROM clause of enclosing selects;
+        this Select may want to leave those absent if it is automatically
+        correlating.
         """
+
         froms = util.OrderedSet()
         hide_froms = util.Set()
         
         
         if len(froms) > 1:
             corr = self.__correlate
-            if correlation_state is not None:
-                corr = correlation_state[self].get('correlate', util.Set()).union(corr)
+            if self._should_correlate and existing_froms is not None:
+                corr = existing_froms.union(corr)
             f = froms.difference(corr)
             if len(f) == 0:
                 raise exceptions.InvalidRequestError("Select statement '%s' is overcorrelated; returned no 'from' clauses" % str(self.__dont_correlate()))
             for f in elem._get_from_objects():
                 froms.add(f)
         return froms
-        
-    def _calculate_correlations(self, correlation_state):
-        """generate a 'correlation_state' dictionary used by the _get_display_froms() method.
-        
-        The dictionary is passed in initially empty, or already 
-        containing the state information added by an enclosing 
-        Select construct.  The method will traverse through all 
-        embedded Select statements and add information about their 
-        position and "from" objects to the dictionary.  Those Select 
-        statements will later consult the 'correlation_state' dictionary 
-        when their list of 'FROM' clauses are generated using their 
-        _get_display_froms() method.
-        """
-        
-        if self not in correlation_state:
-            correlation_state[self] = {}
-
-        display_froms = self._get_display_froms(correlation_state)
-        
-        class CorrelatedVisitor(NoColumnVisitor):
-            def __init__(self, is_where=False, is_column=False, is_from=False):
-                self.is_where = is_where
-                self.is_column = is_column
-                self.is_from = is_from
-                
-            def visit_compound_select(self, cs):
-                self.visit_select(cs)
-
-            def visit_select(s, select):
-                if select not in correlation_state:
-                    correlation_state[select] = {}
-                    
-                if select is self:
-                    return
-                    
-                select_state = correlation_state[select]
-                if s.is_from:
-                    select_state['is_selected_from'] = True
-                if s.is_where:
-                    select_state['is_where'] = True
-                select_state['is_subquery'] = True
-
-                if select._should_correlate:
-                    corr = select_state.setdefault('correlate', util.Set())
-                    # not crazy about this part.  need to be clearer on what elements in the
-                    # subquery correspond to elements in the enclosing query.
-                    for f in display_froms:
-                        corr.add(f)
-                        for f2 in f._get_from_objects():
-                            corr.add(f2)
-        
-        col_vis = CorrelatedVisitor(is_column=True)
-        where_vis = CorrelatedVisitor(is_where=True)
-        from_vis = CorrelatedVisitor(is_from=True)
-    
-        for col in self._raw_columns:
-            col_vis.traverse(col)
-            for f in col._get_from_objects():
-                if f is not self:
-                    from_vis.traverse(f)
-
-        for col in list(self._order_by_clause) + list(self._group_by_clause):
-            col_vis.traverse(col)
-            
-        if self._whereclause is not None:
-            where_vis.traverse(self._whereclause)
-            for f in self._whereclause._get_from_objects(is_where=True):
-                if f is not self:
-                    from_vis.traverse(f)
-                
-        for elem in self._froms:
-            from_vis.traverse(elem)
 
     def _get_inner_columns(self):
         for c in self._raw_columns:
     def supports_execution(self):
         return True
 
-    def _calculate_correlations(self, correlate_state):
-        class SelectCorrelator(NoColumnVisitor):
-            def visit_select(s, select):
-                if select._should_correlate:
-                    select_state = correlate_state.setdefault(select, {})
-                    corr = select_state.setdefault('correlate', util.Set())
-                    corr.add(self.table)
-                    
-        vis = SelectCorrelator()
-        
-        if self._whereclause is not None:
-            vis.traverse(self._whereclause)
-        
-        if getattr(self, 'parameters', None) is not None:
-            for key, value in self.parameters.items():
-                if isinstance(value, ClauseElement):
-                    vis.traverse(value)
-                
     def _process_colparams(self, parameters):
         """Receive the *values* of an ``INSERT`` or ``UPDATE``
         statement and construct appropriate bind parameters.