Commits

Mike Bayer committed d816440

- ANSICompiler now uses its own traversal when compiling, returning a composed
string from each visit_XXXX method, so that the full string is compiled at once
without using any dictionary storage. dialects modified accordingly.
tested on mysql/sqlite/postgres fully,
tested with string-only tests for oracle/fb/informix/mssql so far.

Comments (0)

Files changed (14)

lib/sqlalchemy/ansisql.py

         # actually present in the generated SQL
         self.bind_names = {}
 
-        # a dictionary which stores the string representation for every ClauseElement
-        # processed by this compiler.
-        self.strings = {}
-
-        # a dictionary which stores the string representation for ClauseElements
-        # processed by this compiler, which are to be used in the FROM clause
-        # of a select.  items are often placed in "froms" as well as "strings"
-        # and sometimes with different representations.
-        self.froms = {}
-
-        # slightly hacky.  maps FROM clauses to WHERE clauses, and used in select
-        # generation to modify the WHERE clause of the select.  currently a hack
-        # used by the oracle module.
-        self.wheres = {}
-
         # when the compiler visits a SELECT statement, the clause object is appended
         # to this stack.  various visit operations will check this stack to determine
         # additional choices (TODO: it seems to be all typemap stuff.  shouldnt this only
         # this re will search for params like :param
         # it has a negative lookbehind for an extra ':' so that it doesnt match
         # postgres '::text' tokens
-        text = self.strings[self.statement]
+        text = self.string
         if ':' not in text:
             return
         
                 text = BIND_PARAMS.sub(getnum, text)
         # un-escape any \:params
         text = BIND_PARAMS_ESC.sub(lambda m: m.group(1), text)
-        self.strings[self.statement] = text
+        self.string = text
 
+    def compile(self):
+        self.string = self.process(self.statement)
+        self.after_compile()
+    
+    def process(self, obj, **kwargs):
+        return self.traverse_single(obj, **kwargs)
+        
     def is_subquery(self, select):
         return self.correlate_state[select].get('is_subquery', False)
         
     def get_whereclause(self, obj):
-        return self.wheres.get(obj, None)
+        """given a FROM clause, return an additional WHERE condition that should be 
+        applied to a SELECT. 
+        
+        Currently used by Oracle to provide WHERE criterion for JOIN and OUTER JOIN
+        constructs in non-ansi mode.
+        """
+        
+        return None
 
     def construct_params(self, params):
         """Return a sql.ClauseParameters object.
 
         return ""
     
-    def visit_grouping(self, grouping):
-        self.strings[grouping] = self.froms[grouping] = "(" + self.strings[grouping.elem] + ")"
+    def visit_grouping(self, grouping, **kwargs):
+        return "(" + self.process(grouping.elem) + ")"
         
     def visit_label(self, label):
         labelname = self._truncated_identifier("colident", label.name)
             if isinstance(label.obj, sql._ColumnClause):
                 self.column_labels[label.obj._label] = labelname
             self.column_labels[label.name] = labelname
-        self.strings[label] = " ".join([self.strings[label.obj], self.operator_string(sql.ColumnOperators.as_), self.preparer.format_label(label, labelname)])
+        return " ".join([self.process(label.obj), self.operator_string(sql.ColumnOperators.as_), self.preparer.format_label(label, labelname)])
         
-    def visit_column(self, column):
+    def visit_column(self, column, **kwargs):
         # there is actually somewhat of a ruleset when you would *not* necessarily
         # want to truncate a column identifier, if its mapped to the name of a 
         # physical column.  but thats very hard to identify at this point, and 
         else:
             name = column.name
 
-        if column.table is None or not column.table.named_with_column():
-            self.strings[column] = self.preparer.format_column(column, name=name)
-        else:
-            if column.table.oid_column is column:
-                n = self.dialect.oid_column_name(column)
-                if n is not None:
-                    self.strings[column] = "%s.%s" % (self.preparer.format_table(column.table, use_schema=False, name=self._anonymize(column.table.name)), n)
-                elif len(column.table.primary_key) != 0:
-                    pk = list(column.table.primary_key)[0]
-                    pkname = (pk.is_literal and name or self._truncated_identifier("colident", pk.name))
-                    self.strings[column] = self.preparer.format_column_with_table(list(column.table.primary_key)[0], column_name=pkname, table_name=self._anonymize(column.table.name))
-                else:
-                    self.strings[column] = None
-            else:
-                self.strings[column] = self.preparer.format_column_with_table(column, column_name=name, table_name=self._anonymize(column.table.name))
-
         if len(self.select_stack):
             # if we are within a visit to a Select, set up the "typemap"
             # for this column which is used to translate result set values
             self.typemap.setdefault(name.lower(), column.type)
             self.column_labels.setdefault(column._label, name.lower())
 
-    def visit_fromclause(self, fromclause):
-        self.froms[fromclause] = fromclause.name
+        if column.table is None or not column.table.named_with_column():
+            return self.preparer.format_column(column, name=name)
+        else:
+            if column.table.oid_column is column:
+                n = self.dialect.oid_column_name(column)
+                if n is not None:
+                    return "%s.%s" % (self.preparer.format_table(column.table, use_schema=False, name=self._anonymize(column.table.name)), n)
+                elif len(column.table.primary_key) != 0:
+                    pk = list(column.table.primary_key)[0]
+                    pkname = (pk.is_literal and name or self._truncated_identifier("colident", pk.name))
+                    return self.preparer.format_column_with_table(list(column.table.primary_key)[0], column_name=pkname, table_name=self._anonymize(column.table.name))
+                else:
+                    return None
+            else:
+                return self.preparer.format_column_with_table(column, column_name=name, table_name=self._anonymize(column.table.name))
 
-    def visit_index(self, index):
-        self.strings[index] = index.name
 
-    def visit_typeclause(self, typeclause):
-        self.strings[typeclause] = typeclause.type.dialect_impl(self.dialect).get_col_spec()
+    def visit_fromclause(self, fromclause, **kwargs):
+        return fromclause.name
 
-    def visit_textclause(self, textclause):
-        self.strings[textclause] = textclause.text
-        self.froms[textclause] = textclause.text
+    def visit_index(self, index, **kwargs):
+        return index.name
+
+    def visit_typeclause(self, typeclause, **kwargs):
+        return typeclause.type.dialect_impl(self.dialect).get_col_spec()
+
+    def visit_textclause(self, textclause, **kwargs):
+        for bind in textclause.bindparams.values():
+            self.process(bind)
         if textclause.typemap is not None:
             self.typemap.update(textclause.typemap)
+        return textclause.text
 
-    def visit_null(self, null):
-        self.strings[null] = 'NULL'
+    def visit_null(self, null, **kwargs):
+        return 'NULL'
 
-    def visit_clauselist(self, clauselist):
+    def visit_clauselist(self, clauselist, **kwargs):
         sep = clauselist.operator
         if sep is None:
             sep = " "
             sep = ', '
         else:
             sep = " " + self.operator_string(clauselist.operator) + " "
-        self.strings[clauselist] = string.join([s for s in [self.strings[c] for c in clauselist.clauses] if s is not None], sep)
+        return string.join([s for s in [self.process(c) for c in clauselist.clauses] if s is not None], sep)
 
     def apply_function_parens(self, func):
         return func.name.upper() not in ANSI_FUNCS or len(func.clauses) > 0
 
-    def visit_calculatedclause(self, clause):
-        self.strings[clause] = self.strings[clause.clause_expr]
+    def visit_calculatedclause(self, clause, **kwargs):
+        return self.process(clause.clause_expr)
 
-    def visit_cast(self, cast):
+    def visit_cast(self, cast, **kwargs):
         if len(self.select_stack):
             # not sure if we want to set the typemap here...
             self.typemap.setdefault("CAST", cast.type)
-        self.strings[cast] = "CAST(%s AS %s)" % (self.strings[cast.clause],self.strings[cast.typeclause])
+        return "CAST(%s AS %s)" % (self.process(cast.clause), self.process(cast.typeclause))
 
-    def visit_function(self, func):
+    def visit_function(self, func, **kwargs):
         if len(self.select_stack):
             self.typemap.setdefault(func.name, func.type)
         if not self.apply_function_parens(func):
-            self.strings[func] = ".".join(func.packagenames + [func.name])
-            self.froms[func] = self.strings[func]
+            return ".".join(func.packagenames + [func.name])
         else:
-            self.strings[func] = ".".join(func.packagenames + [func.name]) + (not func.group and " " or "") + self.strings[func.clause_expr]
-            self.froms[func] = self.strings[func]
+            return ".".join(func.packagenames + [func.name]) + (not func.group and " " or "") + self.process(func.clause_expr)
 
-    def visit_compound_select(self, cs):
-        text = string.join([self.strings[c] for c in cs.selects], " " + cs.keyword + " ")
-        group_by = self.strings[cs._group_by_clause]
+    def visit_compound_select(self, cs, asfrom=False, **kwargs):
+        text = string.join([self.process(c) for c in cs.selects], " " + cs.keyword + " ")
+        group_by = self.process(cs._group_by_clause)
         if group_by:
             text += " GROUP BY " + group_by
         text += self.order_by_clause(cs)            
-        text += self.visit_select_postclauses(cs)
-        self.strings[cs] = text
-        self.froms[cs] = "(" + text + ")"
+        text += (cs._limit or cs._offset) and self.limit_clause(cs) or ""
+        
+        if asfrom:
+            return "(" + text + ")"
+        else:
+            return text
 
-    def visit_unary(self, unary):
-        s = self.strings[unary.element]
+    def visit_unary(self, unary, **kwargs):
+        s = self.process(unary.element)
         if unary.operator:
             s = self.operator_string(unary.operator) + " " + s
         if unary.modifier:
             s = s + " " + unary.modifier
-        self.strings[unary] = s
+        return s
         
-    def visit_binary(self, binary):
+    def visit_binary(self, binary, **kwargs):
         op = self.operator_string(binary.operator)
         if callable(op):
-            self.strings[binary] = op(binary.left, binary.right)
+            return op(self.process(binary.left), self.process(binary.right))
         else:
-            self.strings[binary] = self.strings[binary.left] + " " + op + " " + self.strings[binary.right]
+            return self.process(binary.left) + " " + op + " " + self.process(binary.right)
         
     def operator_string(self, operator):
         return self.operators.get(operator, str(operator))
 
-    def visit_bindparam(self, bindparam):
+    def visit_bindparam(self, bindparam, **kwargs):
         # apply truncation to the ultimate generated name
 
         if bindparam.shortname != bindparam.key:
                 key = bindparam.key + tag
                 count += 1
             bindparam.key = key
-            self.strings[bindparam] = self.bindparam_string(self._truncate_bindparam(bindparam))
+            return self.bindparam_string(self._truncate_bindparam(bindparam))
         else:
             existing = self.binds.get(bindparam.key)
             if existing is not None and existing.unique:
                 raise exceptions.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key)
-            self.strings[bindparam] = self.bindparam_string(self._truncate_bindparam(bindparam))
             self.binds[bindparam.key] = bindparam
+            return self.bindparam_string(self._truncate_bindparam(bindparam))
     
     def _truncate_bindparam(self, bindparam):
         if bindparam in self.bind_names:
     def bindparam_string(self, name):
         return self.bindtemplate % name
 
-    def visit_alias(self, alias):
-        self.froms[alias] = self.froms[alias.original] + " AS " + self.preparer.format_alias(alias, self._anonymize(alias.name))
-        self.strings[alias] = self.strings[alias.original]
+    def visit_alias(self, alias, asfrom=False, **kwargs):
+        if asfrom:
+            return self.process(alias.original, asfrom=True, **kwargs) + " AS " + self.preparer.format_alias(alias, self._anonymize(alias.name))
+        else:
+            return self.process(alias.original, **kwargs)
 
-    def enter_select(self, select):
-        select._calculate_correlations(self.correlate_state)
-        self.select_stack.append(select)
-    
-    def enter_update(self, update):
-        update._calculate_correlations(self.correlate_state)
-
-    def enter_delete(self, delete):
-        delete._calculate_correlations(self.correlate_state)
-    
     def label_select_column(self, select, column):
         """convert a column from a select's "columns" clause.
         
         else:
             return None
             
-    def visit_select(self, select):
+    def visit_select(self, select, asfrom=False, **kwargs):
+
+        select._calculate_correlations(self.correlate_state)
+        self.select_stack.append(select)
+
         # the actual list of columns to print in the SELECT column list.
-        inner_columns = util.OrderedDict()
+        inner_columns = util.OrderedSet()
         
         froms = select._get_display_froms(self.correlate_state)
-        for f in froms:
-            if f not in self.strings:
-                self.traverse(f)
                 
         for co in select.inner_columns:
             if select.use_labels:
                 labelname = co._label
                 if labelname is not None:
                     l = co.label(labelname)
-                    self.traverse(l)
-                    inner_columns[labelname] = l
+                    inner_columns.add(self.process(l))
                 else:
                     self.traverse(co)
-                    inner_columns[self.strings[co]] = co
+                    inner_columns.add(self.process(co))
             else:
                 l = self.label_select_column(select, co)
                 if l is not None:
-                    self.traverse(l)
-                    inner_columns[self.strings[l.obj]] = l
+                    inner_columns.add(self.process(l))
                 else:
-                    self.traverse(co)
-                    inner_columns[self.strings[co]] = co
+                    inner_columns.add(self.process(co))
                     
         self.select_stack.pop(-1)
 
-        collist = string.join([self.strings[v] for v in inner_columns.values()], ', ')
+        collist = string.join(inner_columns.difference(util.Set([None])), ', ')
 
         text = "SELECT "
-        text += self.visit_select_precolumns(select)
+        text += self.get_select_precolumns(select)
         text += collist
 
         whereclause = select._whereclause
 
         from_strings = []
         for f in froms:
-            # special thingy used by oracle to redefine a join
+            from_strings.append(self.process(f, asfrom=True))
+
             w = self.get_whereclause(f)
             if w is not None:
-                # TODO: move this more into the oracle module
                 if whereclause is not None:
-                    whereclause = self.traverse(sql.and_(w, whereclause), stop_on=util.Set([whereclause, w]))
+                    whereclause = sql.and_(w, whereclause)
                 else:
                     whereclause = w
 
-            from_strings.append(self.froms[f])
-
         if len(froms):
             text += " \nFROM "
             text += string.join(from_strings, ', ')
             text += self.default_from()
 
         if whereclause is not None:
-            t = self.strings[whereclause]
+            t = self.process(whereclause)
             if t:
                 text += " \nWHERE " + t
 
-        group_by = self.strings[select._group_by_clause]
+        group_by = self.process(select._group_by_clause)
         if group_by:
             text += " GROUP BY " + group_by
 
         if select._having is not None:
-            t = self.strings[select._having]
+            t = self.process(select._having)
             if t:
                 text += " \nHAVING " + t
 
         text += self.order_by_clause(select)
-        text += self.visit_select_postclauses(select)
+        text += (select._limit or select._offset) and self.limit_clause(select) or ""
         text += self.for_update_clause(select)
 
-        self.strings[select] = text
-        self.froms[select] = "(" + text + ")"
+        if asfrom:
+            return "(" + text + ")"
+        else:
+            return text
 
-    def visit_select_precolumns(self, select):
+    def get_select_precolumns(self, select):
         """Called when building a ``SELECT`` statement, position is just before column list."""
-
         return select._distinct and "DISTINCT " or ""
 
-    def visit_select_postclauses(self, select):
-        """Called when building a ``SELECT`` statement, position is after all other ``SELECT`` clauses.
-
-        Most DB syntaxes put ``LIMIT``/``OFFSET`` here.
-        """
-
-        return (select._limit or select._offset) and self.limit_clause(select) or ""
-
     def order_by_clause(self, select):
-        order_by = self.strings[select._order_by_clause]
+        order_by = self.process(select._order_by_clause)
         if order_by:
             return " ORDER BY " + order_by
         else:
             text += " OFFSET " + str(select._offset)
         return text
 
-    def visit_table(self, table):
-        self.froms[table] = self.preparer.format_table(table)
-        self.strings[table] = ""
+    def visit_table(self, table, asfrom=False, **kwargs):
+        if asfrom:
+            return self.preparer.format_table(table)
+        else:
+            return ""
 
-    def visit_join(self, join):
-        righttext = self.froms[join.right]
-        if join.isouter:
-            self.froms[join] = (self.froms[join.left] + " LEFT OUTER JOIN " + righttext +
-            " ON " + self.strings[join.onclause])
-        else:
-            self.froms[join] = (self.froms[join.left] + " JOIN " + righttext +
-            " ON " + self.strings[join.onclause])
-        self.strings[join] = self.froms[join]
+    def visit_join(self, join, asfrom=False, **kwargs):
+        return (self.process(join.left, asfrom=True) + (join.isouter and " LEFT OUTER JOIN " or " JOIN ") + \
+            self.process(join.right, asfrom=True) + " ON " + self.process(join.onclause))
 
     def uses_sequences_for_inserts(self):
         return False
         self.isinsert = True
         colparams = self._get_colparams(insert_stmt, required_cols)
 
-        text = ("INSERT INTO " + self.preparer.format_table(insert_stmt.table) + " (" + string.join([self.preparer.format_column(c[0]) for c in colparams], ', ') + ")" +
+        return ("INSERT INTO " + self.preparer.format_table(insert_stmt.table) + " (" + string.join([self.preparer.format_column(c[0]) for c in colparams], ', ') + ")" +
          " VALUES (" + string.join([c[1] for c in colparams], ', ') + ")")
 
-        self.strings[insert_stmt] = text
-
     def visit_update(self, update_stmt):
+        update_stmt._calculate_correlations(self.correlate_state)
         
         # search for columns who will be required to have an explicit bound value.
         # for updates, this includes Python-side "onupdate" defaults.
         text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.format_column(c[0]), c[1]) for c in colparams], ', ')
 
         if update_stmt._whereclause:
-            text += " WHERE " + self.strings[update_stmt._whereclause]
+            text += " WHERE " + self.process(update_stmt._whereclause)
 
-        self.strings[update_stmt] = text
+        return text
 
     def _get_colparams(self, stmt, required_cols):
         """create a set of tuples representing column/string pairs for use 
         def create_clause_param(col, value):
             self.traverse(value)
             self.inline_params.add(col)
-            return self.strings[value]
+            return self.process(value)
 
         self.inline_params = util.Set()
 
         return values
 
     def visit_delete(self, delete_stmt):
+        delete_stmt._calculate_correlations(self.correlate_state)
+
         text = "DELETE FROM " + self.preparer.format_table(delete_stmt.table)
 
         if delete_stmt._whereclause:
-            text += " WHERE " + self.strings[delete_stmt._whereclause]
+            text += " WHERE " + self.process(delete_stmt._whereclause)
 
-        self.strings[delete_stmt] = text
+        return text
         
     def visit_savepoint(self, savepoint_stmt):
-        text = "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
-        self.strings[savepoint_stmt] = text
+        return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
 
     def visit_rollback_to_savepoint(self, savepoint_stmt):
-        text = "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
-        self.strings[savepoint_stmt] = text
+        return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
     
     def visit_release_savepoint(self, savepoint_stmt):
-        text = "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
-        self.strings[savepoint_stmt] = text
+        return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt.ident)
     
     def __str__(self):
-        return self.strings[self.statement]
+        return self.string
 
 class ANSISchemaBase(engine.SchemaIterator):
     def find_alterables(self, tables):

lib/sqlalchemy/databases/firebird.py

 class FBCompiler(ansisql.ANSICompiler):
     """Firebird specific idiosincrasies"""
 
-    def visit_alias(self, alias):
+    def visit_alias(self, alias, asfrom=False, **kwargs):
         # Override to not use the AS keyword which FB 1.5 does not like
-        self.froms[alias] = self.froms[alias.original] + " " + self.preparer.format_alias(alias)
-        self.strings[alias] = self.strings[alias.original]
+        if asfrom:
+            return self.process(alias.original, asfrom=True) + " " + self.preparer.format_alias(alias)
+        else:
+            return self.process(alias.original, asfrom=True)
 
     def visit_function(self, func):
         if len(func.clauses):
-            super(FBCompiler, self).visit_function(func)
+            return super(FBCompiler, self).visit_function(func)
         else:
-            self.strings[func] = func.name
+            return func.name
 
-    def visit_insert_column(self, column, parameters):
-        # all column primary key inserts must be explicitly present
-        if column.primary_key:
-            parameters[column.key] = None
+    def uses_sequences_for_inserts(self):
+        return True
 
-    def visit_select_precolumns(self, select):
+    def get_select_precolumns(self, select):
         """Called when building a ``SELECT`` statement, position is just
         before column list Firebird puts the limit and offset right
         after the ``SELECT``...
         return result
 
     def limit_clause(self, select):
-        """Already taken care of in the `visit_select_precolumns` method."""
+        """Already taken care of in the `get_select_precolumns` method."""
         return ""
 
 

lib/sqlalchemy/databases/informix.py

     def default_from(self):
         return " from systables where tabname = 'systables' "
     
-    def visit_select_precolumns( self , select ):
+    def get_select_precolumns( self , select ):
         s = select._distinct and "DISTINCT " or ""
         # only has limit
         if select._limit:
                 return c._label.lower()
             except:
                 return ''
-                
+        
+        # TODO: dont modify the original select, generate a new one        
         a = [ __label(c) for c in select._raw_columns ]
         for c in select.order_by_clause.clauses:
             if ( __label(c) not in a ) and getattr( c , 'name' , '' ) != 'oid':
                 select.append_column( c )
         
-        ansisql.ANSICompiler.visit_select(self, select)
+        return ansisql.ANSICompiler.visit_select(self, select)
         
     def limit_clause(self, select):
         return ""
 
     def visit_function( self , func ):
         if func.name.lower() == 'current_date':
-            self.strings[func] = "today"
+            return "today"
         elif func.name.lower() == 'current_time':
-            self.strings[func] = "CURRENT HOUR TO SECOND"
+            return "CURRENT HOUR TO SECOND"
         elif func.name.lower() in ( 'current_timestamp' , 'now' ):
-            self.strings[func] = "CURRENT YEAR TO SECOND"
+            return "CURRENT YEAR TO SECOND"
         else:
-            ansisql.ANSICompiler.visit_function( self , func )
+            return ansisql.ANSICompiler.visit_function( self , func )
             
     def visit_clauselist(self, list):
         try:
             li = [ c for c in list.clauses if c.name != 'oid' ]
         except:
             li = [ c for c in list.clauses ]
-        if list.parens:
-            self.strings[list] = "(" + ', '.join([s for s in [self.strings[c] for c in li] if s is not None ]) + ")"
-        else:
-            self.strings[list] = ', '.join([s for s in [self.strings[c] for c in li] if s is not None])
+        return ', '.join([s for s in [self.process(c) for c in li] if s is not None])
 
 class InfoSchemaGenerator(ansisql.ANSISchemaGenerator):
     def get_column_specification(self, column, first_pk=False):

lib/sqlalchemy/databases/mssql.py

         super(MSSQLCompiler, self).__init__(dialect, statement, parameters, **kwargs)
         self.tablealiases = {}
 
-    def visit_select_precolumns(self, select):
+    def get_select_precolumns(self, select):
         """ MS-SQL puts TOP, it's version of LIMIT here """
         s = select._distinct and "DISTINCT " or ""
         if select._limit:
         # Limit in mssql is after the select keyword
         return ""
             
-    def visit_table(self, table):
+    def _schema_aliased_table(self, table):
+        if getattr(table, 'schema', None) is not None:
+            if not self.tablealiases.has_key(table):
+                self.tablealiases[table] = table.alias()
+            return self.tablealiases[table]
+        else:
+            return None
+            
+    def visit_table(self, table, mssql_aliased=False, **kwargs):
+        if mssql_aliased:
+            return super(MSSQLCompiler, self).visit_table(table, **kwargs)
+
         # alias schema-qualified tables
-        if getattr(table, 'schema', None) is not None and not self.tablealiases.has_key(table):
-            alias = table.alias()
-            self.tablealiases[table] = alias
-            self.traverse(alias)
-            self.froms[('alias', table)] = self.froms[table]
-            for c in alias.c:
-                self.traverse(c)
-            self.traverse(alias.oid_column)
-            self.tablealiases[alias] = self.froms[table]
-            self.froms[table] = self.froms[alias]
+        alias = self._schema_aliased_table(table)
+        if alias is not None:
+            return self.process(alias, mssql_aliased=True, **kwargs)
         else:
-           super(MSSQLCompiler, self).visit_table(table)
+            return super(MSSQLCompiler, self).visit_table(table, **kwargs)
  
-    def visit_alias(self, alias):
+    def visit_alias(self, alias, **kwargs):
         # translate for schema-qualified table aliases
-        if self.froms.has_key(('alias', alias.original)):
-            self.froms[alias] = self.froms[('alias', alias.original)] + " AS " + alias.name
-            self.strings[alias] = ""
-        else:
-            super(MSSQLCompiler, self).visit_alias(alias)
+        self.tablealiases[alias.original] = alias
+        return super(MSSQLCompiler, self).visit_alias(alias, **kwargs)
 
     def visit_column(self, column):
-        # translate for schema-qualified table aliases
-        super(MSSQLCompiler, self).visit_column(column)
-        if column.table is not None and self.tablealiases.has_key(column.table):
-            self.strings[column] = \
-                self.strings[self.tablealiases[column.table].corresponding_column(column)]
+        if column.table is not None:
+            # translate for schema-qualified table aliases
+            t = self._schema_aliased_table(column.table)
+            if t is not None:
+                return self.process(t.corresponding_column(column))
+        return super(MSSQLCompiler, self).visit_column(column)
 
     def visit_binary(self, binary):
         """Move bind parameters to the right-hand side of an operator, where possible."""
-        if isinstance(binary.left, sql._BindParamClause) and binary.operator == '=':
-            binary.left, binary.right = binary.right, binary.left
-        super(MSSQLCompiler, self).visit_binary(binary)
+        if isinstance(binary.left, sql._BindParamClause) and binary.operator == operator.eq:
+            return self.process(sql._BinaryExpression(binary.right, binary.left, binary.operator))
+        else:
+            return super(MSSQLCompiler, self).visit_binary(binary)
 
     def label_select_column(self, select, column):
         if isinstance(column, sql._Function):
         return ''
 
     def order_by_clause(self, select):
-        order_by = self.strings[select._order_by_clause]
+        order_by = self.process(select._order_by_clause)
 
         # MSSQL only allows ORDER BY in subqueries if there is a LIMIT
         if order_by and (not self.is_subquery(select) or select._limit):

lib/sqlalchemy/databases/mysql.py

         }
     )
 
-    def visit_cast(self, cast):
+    def visit_cast(self, cast, **kwargs):
         if isinstance(cast.type, (sqltypes.Date, sqltypes.Time, sqltypes.DateTime)):
-            return super(MySQLCompiler, self).visit_cast(cast)
+            return super(MySQLCompiler, self).visit_cast(cast, **kwargs)
         else:
             # so just skip the CAST altogether for now.
             # TODO: put whatever MySQL does for CAST here.
-            self.strings[cast] = self.strings[cast.clause]
+            return self.process(cast.clause)
 
     def for_update_clause(self, select):
         if select.for_update == 'read':

lib/sqlalchemy/databases/oracle.py

 
 OracleDialect.logger = logging.class_logger(OracleDialect)
 
+class _OuterJoinColumn(sql.ClauseElement):
+    __visit_name__ = 'outer_join_column'
+    def __init__(self, column):
+        self.column = column
+        
 class OracleCompiler(ansisql.ANSICompiler):
     """Oracle compiler modifies the lexical structure of Select
     statements to work under non-ANSI configured Oracle databases, if
         }
     )
 
+    def __init__(self, *args, **kwargs):
+        super(OracleCompiler, self).__init__(*args, **kwargs)
+        self.__wheres = {}
+        
     def default_from(self):
         """Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended.
 
     def apply_function_parens(self, func):
         return len(func.clauses) > 0
 
-    def visit_join(self, join):
+    def visit_join(self, join, **kwargs):
         if self.dialect.use_ansi:
-            return ansisql.ANSICompiler.visit_join(self, join)
+            return ansisql.ANSICompiler.visit_join(self, join, **kwargs)
 
-        self.froms[join] = self.froms[join.left] + ", " + self.froms[join.right]
-        where = self.wheres.get(join.left, None)
+        (where, parentjoin) = self.__wheres.get(join, (None, None))
+
+        class VisitOn(sql.ClauseVisitor):
+            def visit_binary(s, binary):
+                if binary.operator == operator.eq:
+                    if binary.left.table is join.right:
+                        binary.left = _OuterJoinColumn(binary.left)
+                    elif binary.right.table is join.right:
+                        binary.right = _OuterJoinColumn(binary.right)
+                        
         if where is not None:
-            self.wheres[join] = sql.and_(where, join.onclause)
+            self.__wheres[join.left] = self.__wheres[parentjoin] = (sql.and_(VisitOn().traverse(join.onclause, clone=True), where), parentjoin)
         else:
-            self.wheres[join] = join.onclause
-#        self.wheres[join] = sql.and_(self.wheres.get(join.left, None), join.onclause)
-        self.strings[join] = self.froms[join]
+            self.__wheres[join.left] = self.__wheres[join] = (VisitOn().traverse(join.onclause, clone=True), join)
 
-        if join.isouter:
-            # if outer join, push on the right side table as the current "outertable"
-            self._outertable = join.right
-
-            # now re-visit the onclause, which will be used as a where clause
-            # (the first visit occured via the Join object itself right before it called visit_join())
-            self.traverse(join.onclause)
-
-            self._outertable = None
-
-        self.traverse_single(self.wheres[join])
-
+        return self.process(join.left, asfrom=True) + ", " + self.process(join.right, asfrom=True)
+    
+    def get_whereclause(self, f):
+        if f in self.__wheres:
+            return self.__wheres[f][0]
+        else:
+            return None
+            
+    def visit_outer_join_column(self, vc):
+        return self.process(vc.column) + "(+)"
     def uses_sequences_for_inserts(self):
         return True
 
-    def visit_alias(self, alias):
+    def visit_alias(self, alias, asfrom=False, **kwargs):
         """Oracle doesn't like ``FROM table AS alias``.  Is the AS standard SQL??"""
-
-        self.froms[alias] = self.froms[alias.original] + " " + alias.name
-        self.strings[alias] = self.strings[alias.original]
-
-    def visit_column(self, column):
-        ansisql.ANSICompiler.visit_column(self, column)
-        if not self.dialect.use_ansi and getattr(self, '_outertable', None) is not None and column.table is self._outertable:
-            self.strings[column] = self.strings[column] + "(+)"
+        
+        if asfrom:
+            return self.process(alias.original) + " " + alias.name
+        else:
+            return self.process(alias.original)
 
     def visit_insert(self, insert):
         """``INSERT`` s are required to have the primary keys be explicitly present.
         """Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle."""
         pass
 
-    def visit_select(self, select):
+    def visit_select(self, select, **kwargs):
         """Look for ``LIMIT`` and OFFSET in a select statement, and if
         so tries to wrap it in a subquery with ``row_number()`` criterion.
         """
 
         if not getattr(select, '_oracle_visit', None) and (select._limit is not None or select._offset is not None):
             # to use ROW_NUMBER(), an ORDER BY is required.
-            orderby = self.strings[select._order_by_clause]
+            orderby = self.process(select._order_by_clause)
             if not orderby:
                 orderby = select.oid_column
                 self.traverse(orderby)
-                orderby = self.strings[orderby]
+                orderby = self.process(orderby)
                 
             oldselect = select
             select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn")).order_by(None)
                     limitselect.append_whereclause("ora_rn<=%d" % (select._limit + select._offset))
             else:
                 limitselect.append_whereclause("ora_rn<=%d" % select._limit)
-            self.traverse(limitselect)
-            self.strings[oldselect] = self.strings[limitselect]
-            self.froms[oldselect] = self.froms[limitselect]
+            return self.process(limitselect)
         else:
-            ansisql.ANSICompiler.visit_select(self, select)
+            return ansisql.ANSICompiler.visit_select(self, select, **kwargs)
 
     def limit_clause(self, select):
         return ""

lib/sqlalchemy/databases/postgres.py

             text += " OFFSET " + str(select._offset)
         return text
 
-    def visit_select_precolumns(self, select):
+    def get_select_precolumns(self, select):
         if select._distinct:
             if type(select._distinct) == bool:
                 return "DISTINCT "

lib/sqlalchemy/databases/sqlite.py

 class SQLiteCompiler(ansisql.ANSICompiler):
     def visit_cast(self, cast):
         if self.dialect.supports_cast:
-            super(SQLiteCompiler, self).visit_cast(cast)
+            return super(SQLiteCompiler, self).visit_cast(cast)
         else:
             if len(self.select_stack):
                 # not sure if we want to set the typemap here...
                 self.typemap.setdefault("CAST", cast.type)
-            self.strings[cast] = self.strings[cast.clause]
+            return self.process(cast.clause)
 
     def limit_clause(self, select):
         text = ""

lib/sqlalchemy/engine/base.py

         self.bind = bind
         self.can_execute = statement.supports_execution()
 
-    def compile(self):
-        self.traverse(self.statement)
-        self.after_compile()
-
     def __str__(self):
         """Return the string text of the generated SQL statement."""
 

lib/sqlalchemy/sql.py

     """
     __traverse_options__ = {}
     
-    def traverse_single(self, obj):
+    def traverse_single(self, obj, **kwargs):
         meth = getattr(self, "visit_%s" % obj.__visit_name__, None)
         if meth:
-            return meth(obj)
+            return meth(obj, **kwargs)
             
     def traverse(self, obj, stop_on=None, clone=False):
         if clone:
             obj = obj._clone()
 
-        # entry flag indicates to also call a before-descent "enter_XXXX" method
-        entry = self.__traverse_options__.get('entry', False)
-
         v = self
         visitors = []
         while v is not None:
         def _trav(obj):
             if stop_on is not None and obj in stop_on:
                 return
-            if entry:
-                for v in visitors:
-                    meth = getattr(v, "enter_%s" % obj.__visit_name__, None)
-                    if meth:
-                        meth(obj)
-
             if clone:
                 obj._copy_internals()
             for c in obj.get_children(**self.__traverse_options__):

test/orm/sharding/alltests.py

 import testbase
 import unittest
 
-import orm.inheritance.alltests as inheritance
-
 def suite():
     modules_to_test = (
         'orm.sharding.shard',
         for token in name.split('.')[1:]:
             mod = getattr(mod, token)
         alltests.addTest(unittest.findTestCases(mod, suiteClass=None))
-    alltests.addTest(inheritance.suite())
     return alltests
 
 

test/sql/labels.py

         x = select([tt], use_labels=True, order_by=tt.oid_column).compile(dialect=dialect)
         #print x
         # assert it doesnt end with "ORDER BY foo.some_large_named_table_this_is_the_primarykey_column"
-        assert str(x).endswith("""ORDER BY foo.some_large_named_table_t_1""")
+        assert str(x).endswith("""ORDER BY foo.some_large_named_table_t_2""")
 
 if __name__ == '__main__':
     testbase.main()

test/sql/select.py

         crit = q.c.myid == table1.c.myid
         self.runtest(select(['*'], crit), """SELECT * FROM (SELECT mytable.myid AS myid FROM mytable ORDER BY mytable.myid) AS foo, mytable WHERE foo.myid = mytable.myid""", dialect=sqlite.dialect())
         self.runtest(select(['*'], crit), """SELECT * FROM (SELECT mytable.myid AS myid FROM mytable) AS foo, mytable WHERE foo.myid = mytable.myid""", dialect=mssql.dialect())
+
+    def testmssql_aliases_schemas(self):
+        self.runtest(table4.select(), "SELECT remotetable.rem_id, remotetable.datatype_id, remotetable.value FROM remote_owner.remotetable")
+
+        dialect = mssql.dialect()
+        self.runtest(table4.select(), "SELECT remotetable_1.rem_id, remotetable_1.datatype_id, remotetable_1.value FROM remote_owner.remotetable AS remotetable_1", dialect=dialect)
+
+        # TODO: this is probably incorrect; no "AS <foo>" is being applied to the table
+        self.runtest(table1.join(table4, table1.c.myid==table4.c.rem_id).select(), "SELECT mytable.myid, mytable.name, mytable.description, remotetable.rem_id, remotetable.datatype_id, remotetable.value FROM mytable JOIN remote_owner.remotetable ON remotetable.rem_id = mytable.myid")
         
     def testdontovercorrelate(self):
         self.runtest(select([table1], from_obj=[table1, table1.select()]), """SELECT mytable.myid, mytable.name, mytable.description FROM mytable, (SELECT mytable.myid AS myid, mytable.name AS name, mytable.description AS description FROM mytable)""")
                          order_by = ['dist', places.c.nm]
                          )
 
-        self.runtest(q,"SELECT places.id, places.nm, zips.zipcode, latlondist((SELECT zips.latitude FROM zips WHERE zips.zipcode = :zips_zipcode_1), (SELECT zips.longitude FROM zips WHERE zips.zipcode = :zips_zipcode_2)) AS dist FROM places, zips WHERE zips.zipcode = :zips_zipcode ORDER BY dist, places.nm")
+        self.runtest(q,"SELECT places.id, places.nm, zips.zipcode, latlondist((SELECT zips.latitude FROM zips WHERE "
+        "zips.zipcode = :zips_zipcode), (SELECT zips.longitude FROM zips WHERE zips.zipcode = :zips_zipcode_1)) AS dist "
+        "FROM places, zips WHERE zips.zipcode = :zips_zipcode_2 ORDER BY dist, places.nm")
         
         zalias = zips.alias('main_zip')
         qlat = select([zips.c.latitude], zips.c.zipcode == zalias.c.zipcode, scalar=True)
             dialect=postgres.dialect()
             )
 
+
         self.runtest(query, 
             "SELECT mytable.myid, mytable.name, mytable.description, myothertable.otherid, myothertable.othername \
 FROM mytable, myothertable WHERE mytable.myid = myothertable.otherid(+) AND \
             values = {
             table1.c.name : table1.c.name + "lala",
             table1.c.myid : func.do_stuff(table1.c.myid, literal('hoho'))
-            }), "UPDATE mytable SET myid=do_stuff(mytable.myid, :literal_2), name=(mytable.name || :mytable_name) WHERE mytable.myid = hoho(:hoho) AND mytable.name = :literal || mytable.name || :literal_1")
+            }), "UPDATE mytable SET myid=do_stuff(mytable.myid, :literal), name=(mytable.name || :mytable_name) "
+            "WHERE mytable.myid = hoho(:hoho) AND mytable.name = :literal_1 || mytable.name || :literal_2")
         
     def testcorrelatedupdate(self):
         # test against a straight text subquery

test/zblog/tables.py

     Column('user_id', Integer, primary_key=True),
     Column('user_name', String(30), nullable=False),
     Column('fullname', String(100), nullable=False),
-    Column('password', String(30), nullable=False),
+    Column('password', String(40), nullable=False),
     Column('groupname', String(20), nullable=False),
     )