Commits

Mike Bayer committed 3bbfd79

- [feature] Added cte() method to Query,
invokes common table expression support
from the Core (see below). [ticket:1859]

- [feature] Added support for SQL standard
common table expressions (CTE), allowing
SELECT objects as the CTE source (DML
not yet supported). This is invoked via
the cte() method on any select() construct.
[ticket:1859]

Comments (0)

Files changed (7)

     manager to Session, used with with:
     will temporarily disable autoflush.
 
+  - [feature] Added cte() method to Query,
+    invokes common table expression support
+    from the Core (see below). [ticket:1859]
+
   - [bug] Fixed bug whereby MappedCollection
     would not get the appropriate collection
     instrumentation if it were only used
     on the method object.  [ticket:2352]
 
 - sql
+  - [feature] Added support for SQL standard
+    common table expressions (CTE), allowing
+    SELECT objects as the CTE source (DML
+    not yet supported).  This is invoked via
+    the cte() method on any select() construct.
+    [ticket:1859]
+
   - [bug] Added support for using the .key
     of a Column as a string identifier in a 
     result set row.   The .key is currently

doc/build/core/expression_api.rst

    :members:
    :show-inheritance:
 
+.. autoclass:: CTE
+   :members:
+   :show-inheritance:
+
 .. autoclass:: Delete
    :members: where
    :show-inheritance:

lib/sqlalchemy/dialects/mssql/base.py

         ]
         return 'OUTPUT ' + ', '.join(columns)
 
+    def get_cte_preamble(self, recursive):
+        # SQL Server finds it too inconvenient to accept
+        # an entirely optional, SQL standard specified,
+        # "RECURSIVE" word with their "WITH",
+        # so here we go
+        return "WITH"
+
     def label_select_column(self, select, column, asfrom):
         if isinstance(column, expression.Function):
             return column.label(None)

lib/sqlalchemy/orm/query.py

         """
         return self.enable_eagerloads(False).statement.alias(name=name)
 
+    def cte(self, name=None, recursive=False):
+        """Return the full SELECT statement represented by this :class:`.Query`
+        as a common table expression (CTE).
+
+        The :meth:`.Query.cte` method is new in 0.7.6.
+        
+        Parameters and usage are the same as those of the 
+        :meth:`._SelectBase.cte` method; see that method for 
+        further details.
+        
+        Here is the `Postgresql WITH 
+        RECURSIVE example <http://www.postgresql.org/docs/8.4/static/queries-with.html>`_.
+        Note that, in this example, the ``included_parts`` cte and the ``incl_alias`` alias
+        of it are Core selectables, which
+        means the columns are accessed via the ``.c.`` attribute.  The ``parts_alias``
+        object is an :func:`.orm.aliased` instance of the ``Part`` entity, so column-mapped
+        attributes are available directly::
+
+            from sqlalchemy.orm import aliased
+
+            class Part(Base):
+                __tablename__ = 'part'
+                part = Column(String)
+                sub_part = Column(String)
+                quantity = Column(Integer)
+
+            included_parts = session.query(
+                                Part.sub_part, 
+                                Part.part, 
+                                Part.quantity).\\
+                                    filter(Part.part=="our part").\\
+                                    cte(name="included_parts", recursive=True)
+
+            incl_alias = aliased(included_parts, name="pr")
+            parts_alias = aliased(Part, name="p")
+            included_parts = included_parts.union(
+                session.query(
+                    parts_alias.part, 
+                    parts_alias.sub_part, 
+                    parts_alias.quantity).\\
+                        filter(parts_alias.part==incl_alias.c.sub_part)
+                )
+
+            q = session.query(
+                    included_parts.c.sub_part,
+                    func.sum(included_parts.c.quantity).label('total_quantity')
+                ).\\
+                group_by(included_parts.c.sub_part)
+
+        See also:
+        
+        :meth:`._SelectBase.cte`
+
+        """
+        return self.enable_eagerloads(False).statement.cte(name=name, recursive=recursive)
+
     def label(self, name):
         """Return the full SELECT statement represented by this :class:`.Query`, converted 
         to a scalar subquery with a label of the given name.

lib/sqlalchemy/sql/compiler.py

         # column targeting
         self.result_map = {}
 
+        # collect CTEs to tack on top of a SELECT
+        self.ctes = util.OrderedDict()
+        self.ctes_recursive = False
+
         # true if the paramstyle is positional
         self.positional = dialect.positional
         if self.positional:
         else:
             return self.bindtemplate % {'name':name}
 
+    def visit_cte(self, cte, asfrom=False, ashint=False, 
+                                fromhints=None, **kwargs):
+        if isinstance(cte.name, sql._truncated_label):
+            cte_name = self._truncated_identifier("alias", cte.name)
+        else:
+            cte_name = cte.name
+        if cte.cte_alias:
+            if isinstance(cte.cte_alias, sql._truncated_label):
+                cte_alias = self._truncated_identifier("alias", cte.cte_alias)
+            else:
+                cte_alias = cte.cte_alias
+        if not cte.cte_alias and cte not in self.ctes:
+            if cte.recursive:
+                self.ctes_recursive = True
+            text = self.preparer.format_alias(cte, cte_name)
+            if cte.recursive:
+                if isinstance(cte.original, sql.Select):
+                    col_source = cte.original
+                elif isinstance(cte.original, sql.CompoundSelect):
+                    col_source = cte.original.selects[0]
+                else:
+                    assert False
+                recur_cols = [c.key for c in util.unique_list(col_source.inner_columns)
+                                if c is not None]
+
+                text += "(%s)" % (", ".join(recur_cols))
+            text += " AS \n" + \
+                        cte.original._compiler_dispatch(
+                                self, asfrom=True, **kwargs
+                            )
+            self.ctes[cte] = text
+        if asfrom:
+            if cte.cte_alias:
+                text = self.preparer.format_alias(cte, cte_alias)
+                text += " AS " + cte_name
+            else:
+                return self.preparer.format_alias(cte, cte_name)
+            return text
+
     def visit_alias(self, alias, asfrom=False, ashint=False, 
                                 fromhints=None, **kwargs):
         if asfrom or ashint:
         if select.for_update:
             text += self.for_update_clause(select)
 
+        if self.ctes and \
+            compound_index==1 and not entry:
+            cte_text = self.get_cte_preamble(self.ctes_recursive) + " "
+            cte_text += ", \n".join(
+                [txt for txt in self.ctes.values()]
+            )
+            cte_text += "\n "
+            text = cte_text + text
+
         self.stack.pop(-1)
 
         if asfrom and parens:
         else:
             return text
 
+    def get_cte_preamble(self, recursive):
+        if recursive:
+            return "WITH RECURSIVE"
+        else:
+            return "WITH"
+
     def get_select_precolumns(self, select):
         """Called when building a ``SELECT`` statement, position is just
         before column list.

lib/sqlalchemy/sql/expression.py

     def bind(self):
         return self.element.bind
 
+class CTE(Alias):
+    """Represent a Common Table Expression.
+    
+    The :class:`.CTE` object is obtained using the
+    :meth:`._SelectBase.cte` method from any selectable.
+    See that method for complete examples.
+    
+    New in 0.7.6.
+
+    """
+    __visit_name__ = 'cte'
+    def __init__(self, selectable, 
+                        name=None, 
+                        recursive=False, 
+                        cte_alias=False):
+        self.recursive = recursive
+        self.cte_alias = cte_alias
+        super(CTE, self).__init__(selectable, name=name)
+
+    def alias(self, name=None):
+        return CTE(
+            self.original,
+            name=name,
+            recursive=self.recursive,
+            cte_alias = self.name
+        )
+
+    def union(self, other):
+        return CTE(
+            self.original.union(other),
+            name=self.name,
+            recursive=self.recursive
+        )
+
+    def union_all(self, other):
+        return CTE(
+            self.original.union_all(other),
+            name=self.name,
+            recursive=self.recursive
+        )
+
 
 class _Grouping(ColumnElement):
     """Represent a grouping within a column expression"""
         """
         return self.as_scalar().label(name)
 
+    def cte(self, name=None, recursive=False):
+        """Return a new :class:`.CTE`, or Common Table Expression instance.
+        
+        Common table expressions are a SQL standard whereby SELECT
+        statements can draw upon secondary statements specified along
+        with the primary statement, using a clause called "WITH".
+        Special semantics regarding UNION can also be employed to 
+        allow "recursive" queries, where a SELECT statement can draw 
+        upon the set of rows that have previously been selected.
+        
+        SQLAlchemy detects :class:`.CTE` objects, which are treated
+        similarly to :class:`.Alias` objects, as special elements
+        to be delivered to the FROM clause of the statement as well
+        as to a WITH clause at the top of the statement.
+
+        The :meth:`._SelectBase.cte` method is new in 0.7.6.
+        
+        :param name: name given to the common table expression.  Like
+         :meth:`._FromClause.alias`, the name can be left as ``None``
+         in which case an anonymous symbol will be used at query
+         compile time.
+        :param recursive: if ``True``, will render ``WITH RECURSIVE``.
+         A recursive common table expression is intended to be used in 
+         conjunction with UNION or UNION ALL in order to derive rows
+         from those already selected.
+
+        The following are two examples adapted from
+        Postgresql's documentation at
+        http://www.postgresql.org/docs/8.4/static/queries-with.html.
+        
+        Example 1, non recursive::
+        
+            from sqlalchemy import Table, Column, String, Integer, MetaData, \\
+                select, func
+
+            metadata = MetaData()
+
+            orders = Table('orders', metadata,
+                Column('region', String),
+                Column('amount', Integer),
+                Column('product', String),
+                Column('quantity', Integer)
+            )
+
+            regional_sales = select([
+                                orders.c.region, 
+                                func.sum(orders.c.amount).label('total_sales')
+                            ]).group_by(orders.c.region).cte("regional_sales")
+
+
+            top_regions = select([regional_sales.c.region]).\\
+                    where(
+                        regional_sales.c.total_sales > 
+                        select([
+                            func.sum(regional_sales.c.total_sales)/10
+                        ])
+                    ).cte("top_regions")
+
+            statement = select([
+                        orders.c.region, 
+                        orders.c.product, 
+                        func.sum(orders.c.quantity).label("product_units"), 
+                        func.sum(orders.c.amount).label("product_sales")
+                ]).where(orders.c.region.in_(
+                    select([top_regions.c.region])
+                )).group_by(orders.c.region, orders.c.product)
+        
+            result = conn.execute(statement).fetchall()
+            
+        Example 2, WITH RECURSIVE::
+
+            from sqlalchemy import Table, Column, String, Integer, MetaData, \\
+                select, func
+
+            metadata = MetaData()
+
+            parts = Table('parts', metadata,
+                Column('part', String),
+                Column('sub_part', String),
+                Column('quantity', Integer),
+            )
+
+            included_parts = select([
+                                parts.c.sub_part, 
+                                parts.c.part, 
+                                parts.c.quantity]).\\
+                                where(parts.c.part=='our part').\\
+                                cte(recursive=True)
+
+
+            incl_alias = included_parts.alias()
+            parts_alias = parts.alias()
+            included_parts = included_parts.union(
+                select([
+                    parts_alias.c.part, 
+                    parts_alias.c.sub_part, 
+                    parts_alias.c.quantity
+                ]).
+                    where(parts_alias.c.part==incl_alias.c.sub_part)
+            )
+
+            statement = select([
+                        included_parts.c.sub_part, 
+                        func.sum(included_parts.c.quantity).label('total_quantity')
+                    ]).\\
+                    select_from(included_parts.join(parts,
+                                included_parts.c.part==parts.c.part)).\\
+                    group_by(included_parts.c.sub_part)
+
+            result = conn.execute(statement).fetchall()
+
+        
+        See also:
+        
+        :meth:`.orm.query.Query.cte` - ORM version of :meth:`._SelectBase.cte`.
+
+        """
+        return CTE(self, name=name, recursive=recursive)
+
     @_generative
     @util.deprecated('0.6',
                      message=":func:`.autocommit` is deprecated. Use "

test/sql/test_compiler.py

             "SELECT x + foo() OVER () AS anon_1"
         )
 
+    def test_cte_nonrecursive(self):
+        orders = table('orders', 
+            column('region'),
+            column('amount'),
+            column('product'),
+            column('quantity')
+        )
+
+        regional_sales = select([
+                            orders.c.region, 
+                            func.sum(orders.c.amount).label('total_sales')
+                        ]).group_by(orders.c.region).cte("regional_sales")
+
+        top_regions = select([regional_sales.c.region]).\
+                where(
+                    regional_sales.c.total_sales > 
+                    select([
+                        func.sum(regional_sales.c.total_sales)/10
+                    ])
+                ).cte("top_regions")
+
+        s = select([
+                    orders.c.region, 
+                    orders.c.product, 
+                    func.sum(orders.c.quantity).label("product_units"), 
+                    func.sum(orders.c.amount).label("product_sales")
+            ]).where(orders.c.region.in_(
+                select([top_regions.c.region])
+            )).group_by(orders.c.region, orders.c.product)
+
+        # needs to render regional_sales first as top_regions
+        # refers to it
+        self.assert_compile(
+            s,
+            "WITH regional_sales AS (SELECT orders.region AS region, "
+            "sum(orders.amount) AS total_sales FROM orders "
+            "GROUP BY orders.region), "
+            "top_regions AS (SELECT "
+            "regional_sales.region AS region FROM regional_sales "
+            "WHERE regional_sales.total_sales > "
+            "(SELECT sum(regional_sales.total_sales) / :sum_1 AS "
+            "anon_1 FROM regional_sales)) "
+            "SELECT orders.region, orders.product, "
+            "sum(orders.quantity) AS product_units, "
+            "sum(orders.amount) AS product_sales "
+            "FROM orders WHERE orders.region "
+            "IN (SELECT top_regions.region FROM top_regions) "
+            "GROUP BY orders.region, orders.product"
+        )
+
+    def test_cte_recursive(self):
+        parts = table('parts', 
+            column('part'),
+            column('sub_part'),
+            column('quantity'),
+        )
+
+        included_parts = select([
+                            parts.c.sub_part, 
+                            parts.c.part, 
+                            parts.c.quantity]).\
+                            where(parts.c.part=='our part').\
+                                cte(recursive=True)
+
+        incl_alias = included_parts.alias()
+        parts_alias = parts.alias()
+        included_parts = included_parts.union(
+            select([
+                parts_alias.c.part, 
+                parts_alias.c.sub_part, 
+                parts_alias.c.quantity]).\
+                where(parts_alias.c.part==incl_alias.c.sub_part)
+            )
+
+        s = select([
+            included_parts.c.sub_part, 
+            func.sum(included_parts.c.quantity).label('total_quantity')]).\
+            select_from(included_parts.join(
+                    parts,included_parts.c.part==parts.c.part)).\
+            group_by(included_parts.c.sub_part)
+        self.assert_compile(s, 
+                "WITH RECURSIVE anon_1(sub_part, part, quantity) "
+                "AS (SELECT parts.sub_part AS sub_part, parts.part "
+                "AS part, parts.quantity AS quantity FROM parts "
+                "WHERE parts.part = :part_1 UNION SELECT parts_1.part "
+                "AS part, parts_1.sub_part AS sub_part, parts_1.quantity "
+                "AS quantity FROM parts AS parts_1, anon_1 AS anon_2 "
+                "WHERE parts_1.part = anon_2.sub_part) "
+                "SELECT anon_1.sub_part, "
+                "sum(anon_1.quantity) AS total_quantity FROM anon_1 "
+                "JOIN parts ON anon_1.part = parts.part "
+                "GROUP BY anon_1.sub_part"
+            )
+
+        # quick check that the "WITH RECURSIVE" varies per
+        # dialect
+        self.assert_compile(s, 
+                "WITH anon_1(sub_part, part, quantity) "
+                "AS (SELECT parts.sub_part AS sub_part, parts.part "
+                "AS part, parts.quantity AS quantity FROM parts "
+                "WHERE parts.part = :part_1 UNION SELECT parts_1.part "
+                "AS part, parts_1.sub_part AS sub_part, parts_1.quantity "
+                "AS quantity FROM parts AS parts_1, anon_1 AS anon_2 "
+                "WHERE parts_1.part = anon_2.sub_part) "
+                "SELECT anon_1.sub_part, "
+                "sum(anon_1.quantity) AS total_quantity FROM anon_1 "
+                "JOIN parts ON anon_1.part = parts.part "
+                "GROUP BY anon_1.sub_part",
+                dialect=mssql.dialect()
+            )
 
     def test_date_between(self):
         import datetime
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.