Commits

Mike Bayer  committed d59aa93

- move cte tests into their own test/sql/test_cte.py
- rework bindtemplate system of "numbered" params by applying
the numbers last, as we now need to generate these out of order
in some cases
- add positional assertion to assert_compile
- add new cte_positional collection to track bindparams generated
within cte visits; splice this onto the beginning of self.positiontup
at cte render time, [ticket:2521]

  • Participants
  • Parent commits 3c5a9c4

Comments (0)

Files changed (5)

 0.7.9
 =====
 - sql
+  - [bug] Fixed CTE bug whereby positional 
+    bound parameters present in the CTEs themselves
+    would corrupt the overall ordering of 
+    bound parameters.  This primarily
+    affected SQL Server as the platform with 
+    positional binds + CTE support.  
+    [ticket:2521]
+
   - [bug] quoting is applied to the column names
     inside the WITH RECURSIVE clause of a 
     common table expression according to the 

File lib/sqlalchemy/sql/compiler.py

     operators, functions, util as sql_util, visitors, expression as sql
 )
 import decimal
+import itertools
 
 RESERVED_WORDS = set([
     'all', 'analyse', 'analyze', 'and', 'any', 'array',
     'pyformat':"%%(%(name)s)s",
     'qmark':"?",
     'format':"%%s",
-    'numeric':":%(position)s",
+    'numeric':":[_POSITION]",
     'named':":%(name)s"
 }
 
         # column targeting
         self.result_map = {}
 
-        # collect CTEs to tack on top of a SELECT
-        self.ctes = util.OrderedDict()
-        self.ctes_recursive = False
-
         # true if the paramstyle is positional
         self.positional = dialect.positional
         if self.positional:
             self.positiontup = []
         self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]
 
+        # collect CTEs to tack on top of a SELECT
+        self.ctes = util.OrderedDict()
+        self.ctes_recursive = False
+        if self.positional:
+            self.cte_positional = []
+
         # an IdentifierPreparer that formats the quoting of identifiers
         self.preparer = dialect.identifier_preparer
         self.label_length = dialect.label_length \
         self.truncated_names = {}
         engine.Compiled.__init__(self, dialect, statement, **kwargs)
 
+        if self.positional and dialect.paramstyle == 'numeric':
+            self._apply_numbered_params()
 
+    def _apply_numbered_params(self):
+        poscount = itertools.count(1)
+        self.string = re.sub(
+                        r'\[_POSITION\]', 
+                        lambda m:str(next(poscount)), 
+                        self.string)
 
     @util.memoized_property
     def _bind_processors(self):
             if name in textclause.bindparams:
                 return self.process(textclause.bindparams[name])
             else:
-                return self.bindparam_string(name)
+                return self.bindparam_string(name, **kwargs)
 
         # un-escape any \:params
         return BIND_PARAMS_ESC.sub(lambda m: m.group(1),
 
         self.binds[bindparam.key] = self.binds[name] = bindparam
 
-        return self.bindparam_string(name, quote=bindparam.quote)
+        return self.bindparam_string(name, quote=bindparam.quote, **kwargs)
 
     def render_literal_bindparam(self, bindparam, **kw):
         value = bindparam.value
         self.anon_map[derived] = anonymous_counter + 1
         return derived + "_" + str(anonymous_counter)
 
-    def bindparam_string(self, name, quote=None):
+    def bindparam_string(self, name, quote=None, 
+                        positional_names=None, **kw):
         if self.positional:
-            self.positiontup.append(name)
-            return self.bindtemplate % {
-                        'name':name, 'position':len(self.positiontup)}
-        else:
-            return self.bindtemplate % {'name':name}
+            if positional_names is not None:
+                positional_names.append(name)
+            else:
+                self.positiontup.append(name)
+        return self.bindtemplate % {'name':name}
 
     def visit_cte(self, cte, asfrom=False, ashint=False, 
                                 fromhints=None, **kwargs):
+        if self.positional:
+            kwargs['positional_names'] = self.cte_positional
+
         if isinstance(cte.name, sql._truncated_label):
             cte_name = self._truncated_identifier("alias", cte.name)
         else:
             cte_name = cte.name
+
         if cte.cte_alias:
             if isinstance(cte.cte_alias, sql._truncated_label):
                 cte_alias = self._truncated_identifier("alias", cte.cte_alias)
 
     def visit_select(self, select, asfrom=False, parens=True, 
                             iswrapper=False, fromhints=None, 
-                            compound_index=1, **kwargs):
+                            compound_index=1, 
+                            positional_names=None, **kwargs):
 
         entry = self.stack and self.stack[-1] or {}
 
                           : iswrapper})
 
         if compound_index==1 and not entry or entry.get('iswrapper', False):
-            column_clause_args = {'result_map':self.result_map}
+            column_clause_args = {'result_map':self.result_map, 
+                                    'positional_names':positional_names}
         else:
-            column_clause_args = {}
+            column_clause_args = {'positional_names':positional_names}
 
         # the actual list of columns to print in the SELECT column list.
         inner_columns = [
             return text
 
     def _render_cte_clause(self):
+        if self.positional:
+            self.positiontup = self.cte_positional + self.positiontup
         cte_text = self.get_cte_preamble(self.ctes_recursive) + " "
         cte_text += ", \n".join(
             [txt for txt in self.ctes.values()]

File test/lib/testing.py

 class AssertsCompiledSQL(object):
     def assert_compile(self, clause, result, params=None, 
                         checkparams=None, dialect=None, 
+                        checkpositional=None,
                         use_default_dialect=False,
                         allow_dialect_select=False):
         if use_default_dialect:
 
         if checkparams is not None:
             eq_(c.construct_params(params), checkparams)
+        if checkpositional is not None:
+            p = c.construct_params(params)
+            eq_(tuple([p[x] for x in c.positiontup]), checkpositional)
 
 class ComparesTables(object):
     def assert_tables_equal(self, table, reflected_table, strict_types=False):

File test/sql/test_compiler.py

             "SELECT x + foo() OVER () AS anon_1"
         )
 
-    def test_cte_nonrecursive(self):
-        orders = table('orders', 
-            column('region'),
-            column('amount'),
-            column('product'),
-            column('quantity')
-        )
-
-        regional_sales = select([
-                            orders.c.region, 
-                            func.sum(orders.c.amount).label('total_sales')
-                        ]).group_by(orders.c.region).cte("regional_sales")
-
-        top_regions = select([regional_sales.c.region]).\
-                where(
-                    regional_sales.c.total_sales > 
-                    select([
-                        func.sum(regional_sales.c.total_sales)/10
-                    ])
-                ).cte("top_regions")
-
-        s = select([
-                    orders.c.region, 
-                    orders.c.product, 
-                    func.sum(orders.c.quantity).label("product_units"), 
-                    func.sum(orders.c.amount).label("product_sales")
-            ]).where(orders.c.region.in_(
-                select([top_regions.c.region])
-            )).group_by(orders.c.region, orders.c.product)
-
-        # needs to render regional_sales first as top_regions
-        # refers to it
-        self.assert_compile(
-            s,
-            "WITH regional_sales AS (SELECT orders.region AS region, "
-            "sum(orders.amount) AS total_sales FROM orders "
-            "GROUP BY orders.region), "
-            "top_regions AS (SELECT "
-            "regional_sales.region AS region FROM regional_sales "
-            "WHERE regional_sales.total_sales > "
-            "(SELECT sum(regional_sales.total_sales) / :sum_1 AS "
-            "anon_1 FROM regional_sales)) "
-            "SELECT orders.region, orders.product, "
-            "sum(orders.quantity) AS product_units, "
-            "sum(orders.amount) AS product_sales "
-            "FROM orders WHERE orders.region "
-            "IN (SELECT top_regions.region FROM top_regions) "
-            "GROUP BY orders.region, orders.product"
-        )
-
-    def test_cte_recursive(self):
-        parts = table('parts', 
-            column('part'),
-            column('sub_part'),
-            column('quantity'),
-        )
-
-        included_parts = select([
-                            parts.c.sub_part, 
-                            parts.c.part, 
-                            parts.c.quantity]).\
-                            where(parts.c.part=='our part').\
-                                cte(recursive=True)
-
-        incl_alias = included_parts.alias()
-        parts_alias = parts.alias()
-        included_parts = included_parts.union(
-            select([
-                parts_alias.c.part, 
-                parts_alias.c.sub_part, 
-                parts_alias.c.quantity]).\
-                where(parts_alias.c.part==incl_alias.c.sub_part)
-            )
-
-        s = select([
-            included_parts.c.sub_part, 
-            func.sum(included_parts.c.quantity).label('total_quantity')]).\
-            select_from(included_parts.join(
-                    parts,included_parts.c.part==parts.c.part)).\
-            group_by(included_parts.c.sub_part)
-        self.assert_compile(s, 
-                "WITH RECURSIVE anon_1(sub_part, part, quantity) "
-                "AS (SELECT parts.sub_part AS sub_part, parts.part "
-                "AS part, parts.quantity AS quantity FROM parts "
-                "WHERE parts.part = :part_1 UNION SELECT parts_1.part "
-                "AS part, parts_1.sub_part AS sub_part, parts_1.quantity "
-                "AS quantity FROM parts AS parts_1, anon_1 AS anon_2 "
-                "WHERE parts_1.part = anon_2.sub_part) "
-                "SELECT anon_1.sub_part, "
-                "sum(anon_1.quantity) AS total_quantity FROM anon_1 "
-                "JOIN parts ON anon_1.part = parts.part "
-                "GROUP BY anon_1.sub_part"
-            )
-
-        # quick check that the "WITH RECURSIVE" varies per
-        # dialect
-        self.assert_compile(s, 
-                "WITH anon_1(sub_part, part, quantity) "
-                "AS (SELECT parts.sub_part AS sub_part, parts.part "
-                "AS part, parts.quantity AS quantity FROM parts "
-                "WHERE parts.part = :part_1 UNION SELECT parts_1.part "
-                "AS part, parts_1.sub_part AS sub_part, parts_1.quantity "
-                "AS quantity FROM parts AS parts_1, anon_1 AS anon_2 "
-                "WHERE parts_1.part = anon_2.sub_part) "
-                "SELECT anon_1.sub_part, "
-                "sum(anon_1.quantity) AS total_quantity FROM anon_1 "
-                "JOIN parts ON anon_1.part = parts.part "
-                "GROUP BY anon_1.sub_part",
-                dialect=mssql.dialect()
-            )
-
-    def test_cte_union(self):
-        orders = table('orders', 
-            column('region'),
-            column('amount'),
-        )
-
-        regional_sales = select([
-                            orders.c.region,
-                            orders.c.amount
-                        ]).cte("regional_sales")
-
-        s = select([regional_sales.c.region]).\
-                where(
-                    regional_sales.c.amount > 500
-                )
-
-        self.assert_compile(s, 
-            "WITH regional_sales AS "
-            "(SELECT orders.region AS region, "
-            "orders.amount AS amount FROM orders) "
-            "SELECT regional_sales.region "
-            "FROM regional_sales WHERE "
-            "regional_sales.amount > :amount_1")
-
-        s = s.union_all(
-            select([regional_sales.c.region]).\
-                where(
-                    regional_sales.c.amount < 300
-                )
-        )
-        self.assert_compile(s, 
-            "WITH regional_sales AS "
-            "(SELECT orders.region AS region, "
-            "orders.amount AS amount FROM orders) "
-            "SELECT regional_sales.region FROM regional_sales "
-            "WHERE regional_sales.amount > :amount_1 "
-            "UNION ALL SELECT regional_sales.region "
-            "FROM regional_sales WHERE "
-            "regional_sales.amount < :amount_2")
-
-    def test_cte_reserved_quote(self):
-        orders = table('orders', 
-            column('order'),
-        )
-        s = select([orders.c.order]).cte("regional_sales", recursive=True)
-        s = select([s.c.order])
-        self.assert_compile(s,
-            'WITH RECURSIVE regional_sales("order") AS '
-            '(SELECT orders."order" AS "order" '
-            "FROM orders)"
-            ' SELECT regional_sales."order" '
-            "FROM regional_sales"
-            )
 
     def test_date_between(self):
         import datetime

File test/sql/test_cte.py

+from test.lib import fixtures
+from test.lib.testing import AssertsCompiledSQL
+from sqlalchemy.sql import table, column, select, func, literal
+from sqlalchemy.dialects import mssql
+from sqlalchemy.engine import default
+
+class CTETest(fixtures.TestBase, AssertsCompiledSQL):
+
+    __dialect__ = 'default'
+
+    def test_nonrecursive(self):
+        orders = table('orders', 
+            column('region'),
+            column('amount'),
+            column('product'),
+            column('quantity')
+        )
+
+        regional_sales = select([
+                            orders.c.region, 
+                            func.sum(orders.c.amount).label('total_sales')
+                        ]).group_by(orders.c.region).cte("regional_sales")
+
+        top_regions = select([regional_sales.c.region]).\
+                where(
+                    regional_sales.c.total_sales > 
+                    select([
+                        func.sum(regional_sales.c.total_sales)/10
+                    ])
+                ).cte("top_regions")
+
+        s = select([
+                    orders.c.region, 
+                    orders.c.product, 
+                    func.sum(orders.c.quantity).label("product_units"), 
+                    func.sum(orders.c.amount).label("product_sales")
+            ]).where(orders.c.region.in_(
+                select([top_regions.c.region])
+            )).group_by(orders.c.region, orders.c.product)
+
+        # needs to render regional_sales first as top_regions
+        # refers to it
+        self.assert_compile(
+            s,
+            "WITH regional_sales AS (SELECT orders.region AS region, "
+            "sum(orders.amount) AS total_sales FROM orders "
+            "GROUP BY orders.region), "
+            "top_regions AS (SELECT "
+            "regional_sales.region AS region FROM regional_sales "
+            "WHERE regional_sales.total_sales > "
+            "(SELECT sum(regional_sales.total_sales) / :sum_1 AS "
+            "anon_1 FROM regional_sales)) "
+            "SELECT orders.region, orders.product, "
+            "sum(orders.quantity) AS product_units, "
+            "sum(orders.amount) AS product_sales "
+            "FROM orders WHERE orders.region "
+            "IN (SELECT top_regions.region FROM top_regions) "
+            "GROUP BY orders.region, orders.product"
+        )
+
+    def test_recursive(self):
+        parts = table('parts', 
+            column('part'),
+            column('sub_part'),
+            column('quantity'),
+        )
+
+        included_parts = select([
+                            parts.c.sub_part, 
+                            parts.c.part, 
+                            parts.c.quantity]).\
+                            where(parts.c.part=='our part').\
+                                cte(recursive=True)
+
+        incl_alias = included_parts.alias()
+        parts_alias = parts.alias()
+        included_parts = included_parts.union(
+            select([
+                parts_alias.c.part, 
+                parts_alias.c.sub_part, 
+                parts_alias.c.quantity]).\
+                where(parts_alias.c.part==incl_alias.c.sub_part)
+            )
+
+        s = select([
+            included_parts.c.sub_part, 
+            func.sum(included_parts.c.quantity).label('total_quantity')]).\
+            select_from(included_parts.join(
+                    parts,included_parts.c.part==parts.c.part)).\
+            group_by(included_parts.c.sub_part)
+        self.assert_compile(s, 
+                "WITH RECURSIVE anon_1(sub_part, part, quantity) "
+                "AS (SELECT parts.sub_part AS sub_part, parts.part "
+                "AS part, parts.quantity AS quantity FROM parts "
+                "WHERE parts.part = :part_1 UNION SELECT parts_1.part "
+                "AS part, parts_1.sub_part AS sub_part, parts_1.quantity "
+                "AS quantity FROM parts AS parts_1, anon_1 AS anon_2 "
+                "WHERE parts_1.part = anon_2.sub_part) "
+                "SELECT anon_1.sub_part, "
+                "sum(anon_1.quantity) AS total_quantity FROM anon_1 "
+                "JOIN parts ON anon_1.part = parts.part "
+                "GROUP BY anon_1.sub_part"
+            )
+
+        # quick check that the "WITH RECURSIVE" varies per
+        # dialect
+        self.assert_compile(s, 
+                "WITH anon_1(sub_part, part, quantity) "
+                "AS (SELECT parts.sub_part AS sub_part, parts.part "
+                "AS part, parts.quantity AS quantity FROM parts "
+                "WHERE parts.part = :part_1 UNION SELECT parts_1.part "
+                "AS part, parts_1.sub_part AS sub_part, parts_1.quantity "
+                "AS quantity FROM parts AS parts_1, anon_1 AS anon_2 "
+                "WHERE parts_1.part = anon_2.sub_part) "
+                "SELECT anon_1.sub_part, "
+                "sum(anon_1.quantity) AS total_quantity FROM anon_1 "
+                "JOIN parts ON anon_1.part = parts.part "
+                "GROUP BY anon_1.sub_part",
+                dialect=mssql.dialect()
+            )
+
+    def test_union(self):
+        orders = table('orders', 
+            column('region'),
+            column('amount'),
+        )
+
+        regional_sales = select([
+                            orders.c.region,
+                            orders.c.amount
+                        ]).cte("regional_sales")
+
+        s = select([regional_sales.c.region]).\
+                where(
+                    regional_sales.c.amount > 500
+                )
+
+        self.assert_compile(s, 
+            "WITH regional_sales AS "
+            "(SELECT orders.region AS region, "
+            "orders.amount AS amount FROM orders) "
+            "SELECT regional_sales.region "
+            "FROM regional_sales WHERE "
+            "regional_sales.amount > :amount_1")
+
+        s = s.union_all(
+            select([regional_sales.c.region]).\
+                where(
+                    regional_sales.c.amount < 300
+                )
+        )
+        self.assert_compile(s, 
+            "WITH regional_sales AS "
+            "(SELECT orders.region AS region, "
+            "orders.amount AS amount FROM orders) "
+            "SELECT regional_sales.region FROM regional_sales "
+            "WHERE regional_sales.amount > :amount_1 "
+            "UNION ALL SELECT regional_sales.region "
+            "FROM regional_sales WHERE "
+            "regional_sales.amount < :amount_2")
+
+    def test_reserved_quote(self):
+        orders = table('orders', 
+            column('order'),
+        )
+        s = select([orders.c.order]).cte("regional_sales", recursive=True)
+        s = select([s.c.order])
+        self.assert_compile(s,
+            'WITH RECURSIVE regional_sales("order") AS '
+            '(SELECT orders."order" AS "order" '
+            "FROM orders)"
+            ' SELECT regional_sales."order" '
+            "FROM regional_sales"
+            )
+
+    def test_positional_binds(self):
+        orders = table('orders', 
+            column('order'),
+        )
+        s = select([orders.c.order, literal("x")]).cte("regional_sales")
+        s = select([s.c.order, literal("y")])
+        dialect = default.DefaultDialect()
+        dialect.positional = True
+        dialect.paramstyle = 'numeric'
+        self.assert_compile(s,
+            'WITH regional_sales AS (SELECT orders."order" '
+            'AS "order", :1 AS anon_2 FROM orders) SELECT '
+            'regional_sales."order", :2 AS anon_1 FROM regional_sales',
+            checkpositional=('x', 'y'),
+            dialect=dialect
+        )
+
+        self.assert_compile(s.union(s),
+            'WITH regional_sales AS (SELECT orders."order" '
+            'AS "order", :1 AS anon_2 FROM orders) SELECT '
+            'regional_sales."order", :2 AS anon_1 FROM regional_sales '
+            'UNION SELECT regional_sales."order", :3 AS anon_1 '
+            'FROM regional_sales',
+            checkpositional=('x', 'y', 'y'),
+            dialect=dialect
+        )
+
+        s = select([orders.c.order]).\
+            where(orders.c.order=='x').cte("regional_sales")
+        s = select([s.c.order]).where(s.c.order=="y")
+        self.assert_compile(s,
+            'WITH regional_sales AS (SELECT orders."order" AS '
+            '"order" FROM orders WHERE orders."order" = :1) '
+            'SELECT regional_sales."order" FROM regional_sales '
+            'WHERE regional_sales."order" = :2',
+            checkpositional=('x', 'y'),
+            dialect=dialect
+        )