Commits

Mike Bayer committed 94e6c80

- The compiler extension now supports overriding the default
compilation of expression._BindParamClause including that
the auto-generated binds within the VALUES/SET clause
of an insert()/update() statement will also use the new
compilation rules. [ticket:2042]

Comments (0)

Files changed (6)

     the extension compiles and runs on Python 2.4.
     [ticket:2023]
 
+  - The compiler extension now supports overriding the default
+    compilation of expression._BindParamClause including that
+    the auto-generated binds within the VALUES/SET clause
+    of an insert()/update() statement will also use the new
+    compilation rules. [ticket:2042]
+
 - postgresql
   - When explicit sequence execution derives the name 
     of the auto-generated sequence of a SERIAL column, 

lib/sqlalchemy/sql/compiler.py

         else:
             return fn(" " + operator + " ")
 
-    def visit_bindparam(self, bindparam, within_columns_clause=False, 
+    def visit_bindparam(self, bindparam, within_columns_clause=False,
                                             literal_binds=False, **kwargs):
+
         if literal_binds or \
             (within_columns_clause and \
                 self.ansi_bind_rules):
                             within_columns_clause=True, **kwargs)
 
         name = self._truncate_bindparam(bindparam)
+
         if name in self.binds:
             existing = self.binds[name]
             if existing is not bindparam:
                             "unique bind parameter of the same name" %
                             bindparam.key
                         )
-                elif getattr(existing, '_is_crud', False):
+                elif getattr(existing, '_is_crud', False) or \
+                    getattr(bindparam, '_is_crud', False):
                     raise exc.CompileError(
                         "bindparam() name '%s' is reserved "
                         "for automatic usage in the VALUES or SET "
         bindparam = sql.bindparam(col.key, value, 
                             type_=col.type, required=required)
         bindparam._is_crud = True
-        if col.key in self.binds:
-            raise exc.CompileError(
-                    "bindparam() name '%s' is reserved "
-                    "for automatic usage in the VALUES or SET clause of this "
-                    "insert/update statement.   Please use a " 
-                    "name other than column name when using bindparam() "
-                    "with insert() or update() (for example, 'b_%s')."
-                    % (col.key, col.key)
-                )
+        return bindparam._compiler_dispatch(self)
 
-        self.binds[col.key] = bindparam
-        return self.bindparam_string(self._truncate_bindparam(bindparam))
 
     def _get_colparams(self, stmt):
         """create a set of tuples representing column/string pairs for use

lib/sqlalchemy/sql/visitors.py

             super(VisitableType, cls).__init__(clsname, bases, clsdict)
             return
 
-        # set up an optimized visit dispatch function
-        # for use by the compiler
-        if '__visit_name__' in cls.__dict__:
-            visit_name = cls.__visit_name__
-            if isinstance(visit_name, str):
-                getter = operator.attrgetter("visit_%s" % visit_name)
-                def _compiler_dispatch(self, visitor, **kw):
-                    return getter(visitor)(self, **kw)
-            else:
-                def _compiler_dispatch(self, visitor, **kw):
-                    return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw)
-
-            cls._compiler_dispatch = _compiler_dispatch
+        _generate_dispatch(cls)
 
         super(VisitableType, cls).__init__(clsname, bases, clsdict)
 
+def _generate_dispatch(cls):
+    # set up an optimized visit dispatch function
+    # for use by the compiler
+    if '__visit_name__' in cls.__dict__:
+        visit_name = cls.__visit_name__
+        if isinstance(visit_name, str):
+            getter = operator.attrgetter("visit_%s" % visit_name)
+            def _compiler_dispatch(self, visitor, **kw):
+                return getter(visitor)(self, **kw)
+        else:
+            def _compiler_dispatch(self, visitor, **kw):
+                return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw)
+
+        cls._compiler_dispatch = _compiler_dispatch
+
 class Visitable(object):
     """Base class for visitable objects, applies the
     ``VisitableType`` metaclass.

test/aaa_profiling/test_compiler.py

 
         cls.dialect = default.DefaultDialect()
 
-    @profiling.function_call_count(versions={'2.7':58, '2.6':58,
-                                            '3':64})
+    @profiling.function_call_count(versions={'2.7':62, '2.6':62,
+                                            '3':68})
     def test_insert(self):
         t1.insert().compile(dialect=self.dialect)
 
-    @profiling.function_call_count(versions={'2.6':49, '2.7':49})
+    @profiling.function_call_count(versions={'2.6':53, '2.7':53})
     def test_update(self):
         t1.update().compile(dialect=self.dialect)
 

test/ext/test_compiler.py

 from sqlalchemy import *
 from sqlalchemy.types import TypeEngine
 from sqlalchemy.sql.expression import ClauseElement, ColumnClause,\
-                                    FunctionElement, Select
+                                    FunctionElement, Select, \
+                                    _BindParamClause
+
 from sqlalchemy.schema import DDLElement
 from sqlalchemy.ext.compiler import compiles
-from sqlalchemy.sql import table, column
+from sqlalchemy.sql import table, column, visitors
 from test.lib import *
 
 class UserDefinedTest(TestBase, AssertsCompiledSQL):
             if hasattr(Select, '_compiler_dispatcher'):
                 del Select._compiler_dispatcher
 
-    def test_default_on_existing(self):
-        """test that the existing compiler function remains
-        as 'default' when overriding the compilation of an
-        existing construct."""
-
-
-        t1 = table('t1', column('c1'), column('c2'))
-
-        dispatch = Select._compiler_dispatch
-        try:
-
-            @compiles(Select, 'sqlite')
-            def compile(element, compiler, **kw):
-                return "OVERRIDE"
-
-            s1 = select([t1])
-            self.assert_compile(
-                s1, "SELECT t1.c1, t1.c2 FROM t1",
-            )
-
-            from sqlalchemy.dialects.sqlite import base as sqlite
-            self.assert_compile(
-                s1, "OVERRIDE",
-                dialect=sqlite.dialect()
-            )
-        finally:
-            Select._compiler_dispatch = dispatch
-            if hasattr(Select, '_compiler_dispatcher'):
-                del Select._compiler_dispatcher
-
     def test_dialect_specific(self):
         class AddThingy(DDLElement):
             __visit_name__ = 'add_thingy'
             'SELECT FOOsub1, sub2, FOOsubsub1',
             use_default_dialect=True
         )
+
+
+class DefaultOnExistingTest(TestBase, AssertsCompiledSQL):
+    """Test replacement of default compilation on existing constructs."""
+
+    def teardown(self):
+        for cls in (Select, _BindParamClause):
+            if hasattr(cls, '_compiler_dispatcher'):
+                visitors._generate_dispatch(cls)
+                del cls._compiler_dispatcher
+
+    def test_select(self):
+        t1 = table('t1', column('c1'), column('c2'))
+
+        @compiles(Select, 'sqlite')
+        def compile(element, compiler, **kw):
+            return "OVERRIDE"
+
+        s1 = select([t1])
+        self.assert_compile(
+            s1, "SELECT t1.c1, t1.c2 FROM t1",
+        )
+
+        from sqlalchemy.dialects.sqlite import base as sqlite
+        self.assert_compile(
+            s1, "OVERRIDE",
+            dialect=sqlite.dialect()
+        )
+
+    def test_binds_in_select(self):
+        t = table('t',
+            column('a'),
+            column('b'),
+            column('c')
+        )
+
+        @compiles(_BindParamClause)
+        def gen_bind(element, compiler, **kw):
+            return "BIND(%s)" % compiler.visit_bindparam(element, **kw)
+
+        self.assert_compile(
+            t.select().where(t.c.c == 5), 
+            "SELECT t.a, t.b, t.c FROM t WHERE t.c = BIND(:c_1)",
+            use_default_dialect=True
+        )
+
+    def test_binds_in_dml(self):
+        t = table('t',
+            column('a'),
+            column('b'),
+            column('c')
+        )
+
+        @compiles(_BindParamClause)
+        def gen_bind(element, compiler, **kw):
+            return "BIND(%s)" % compiler.visit_bindparam(element, **kw)
+
+        self.assert_compile(
+            t.insert(), 
+            "INSERT INTO t (a, b) VALUES (BIND(:a), BIND(:b))",
+            {'a':1, 'b':2},
+            use_default_dialect=True
+        )

test/lib/testing.py

         assert val, msg
 
 class AssertsCompiledSQL(object):
-    def assert_compile(self, clause, result, params=None, checkparams=None, dialect=None, use_default_dialect=False):
+    def assert_compile(self, clause, result, params=None, 
+                        checkparams=None, dialect=None, 
+                        use_default_dialect=False):
         if use_default_dialect:
             dialect = default.DefaultDialect()
 
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.