Commits

idank committed c3efee6

compiler: adjust _get_colparams to return the columns and parameters in separate lists

Comments (0)

Files changed (1)

lib/sqlalchemy/sql/compiler.py

 
     def visit_insert(self, insert_stmt, **kw):
         self.isinsert = True
-        colparams = self._get_colparams(insert_stmt)
+        cols, params = self._get_colparams(insert_stmt)
 
-        if not colparams and \
+        if not cols and \
                 not self.dialect.supports_default_values and \
                 not self.dialect.supports_empty_insert:
             raise exc.CompileError("The version of %s you are using does "
 
         text += table_text
 
-        if colparams or not supports_default_values:
-            text += " (%s)" % ', '.join([preparer.format_column(c[0])
-                       for c in colparams])
+        if cols or not supports_default_values:
+            text += " (%s)" % ', '.join([preparer.format_column(c)
+                       for c in cols])
 
         if self.returning or insert_stmt._returning:
             self.returning = self.returning or insert_stmt._returning
             if self.returning_precedes_values:
                 text += " " + returning_clause
 
-        if not colparams and supports_default_values:
+        if not cols and supports_default_values:
             text += " DEFAULT VALUES"
         else:
             text += " VALUES (%s)" % \
-                     ', '.join([c[1] for c in colparams])
+                     ', '.join(params[0])
 
         if self.returning and not self.returning_precedes_values:
             text += " " + returning_clause
 
         extra_froms = update_stmt._extra_froms
 
-        colparams = self._get_colparams(update_stmt, extra_froms)
+        cols, params = self._get_colparams(update_stmt, extra_froms)
 
         text = "UPDATE "
 
         text += ' SET '
         include_table = extra_froms and \
                         self.render_table_with_column_in_update_from
+        colparams = []
+        if params:
+            colparams = zip(cols, params[0])
         text += ', '.join(
-                        c[0]._compiler_dispatch(self,
+                        c._compiler_dispatch(self,
                             include_table=include_table) +
-                        '=' + c[1] for c in colparams
+                        '=' + p for c, p in colparams
                         )
 
         if update_stmt._returning:
         # no parameters in the statement, no parameters in the
         # compiled params - return binds for all columns
         if self.column_keys is None and stmt.parameters is None:
-            return [
-                        (c, self._create_crud_bind_param(c,
-                                    None, required=True))
-                        for c in stmt.table.columns
-                    ]
+            values = [self._create_crud_bind_param(c, None, required=True)
+                      for c in stmt.table.columns]
+            return list(stmt.table.columns), [values]
 
         required = object()
 
                               key not in stmt.parameters)
 
         # create a list of column assignment clauses as tuples
+        columns = []
         values = []
 
         if stmt.parameters is not None:
                     else:
                         v = self.process(v.self_group())
 
-                    values.append((k, v))
+                    columns.append(k)
+                    values.append(v)
 
         need_pks = self.isinsert and \
                         not self.inline and \
                         else:
                             self.postfetch.append(c)
                             value = self.process(value.self_group())
-                        values.append((c, value))
+                        columns.append(c)
+                        values.append(value)
             # determine tables which are actually
             # to be updated - process onupdate and
             # server_onupdate for these
                         continue
                     elif c.onupdate is not None and not c.onupdate.is_sequence:
                         if c.onupdate.is_clause_element:
-                            values.append(
-                                (c, self.process(c.onupdate.arg.self_group()))
-                            )
+                            columns.apppend(c)
+                            values.append(self.process(c.onupdate.arg.self_group()))
                             self.postfetch.append(c)
                         else:
-                            values.append(
-                                (c, self._create_crud_bind_param(c, None))
-                            )
+                            columns.append(c)
+                            values.append(self._create_crud_bind_param(c, None))
                             self.prefetch.append(c)
                     elif c.server_onupdate is not None:
                         self.postfetch.append(c)
                 else:
                     self.postfetch.append(c)
                     value = self.process(value.self_group())
-                values.append((c, value))
+                columns.append(c)
+                values.append(value)
 
             elif self.isinsert:
                 if c.primary_key and \
                                     (not c.default.optional or \
                                     not self.dialect.sequences_optional):
                                     proc = self.process(c.default)
-                                    values.append((c, proc))
+                                    columns.append(c)
+                                    values.append(proc)
                                 self.returning.append(c)
                             elif c.default.is_clause_element:
-                                values.append(
-                                    (c,
-                                    self.process(c.default.arg.self_group()))
-                                )
+                                columns.append(c)
+                                values.append(self.process(c.default.arg.self_group()))
                                 self.returning.append(c)
                             else:
-                                values.append(
-                                    (c, self._create_crud_bind_param(c, None))
-                                )
+                                columns.append(c)
+                                values.append(self._create_crud_bind_param(c, None))
                                 self.prefetch.append(c)
                         else:
                             self.returning.append(c)
                                 self.dialect.preexecute_autoincrement_sequences
                             ):
 
-                            values.append(
-                                (c, self._create_crud_bind_param(c, None))
-                            )
-
+                            columns.append(c)
+                            values.append(self._create_crud_bind_param(c, None))
                             self.prefetch.append(c)
 
                 elif c.default is not None:
                             (not c.default.optional or \
                             not self.dialect.sequences_optional):
                             proc = self.process(c.default)
-                            values.append((c, proc))
+                            columns.append(c)
+                            values.append(proc)
                             if not c.primary_key:
                                 self.postfetch.append(c)
                     elif c.default.is_clause_element:
-                        values.append(
-                            (c, self.process(c.default.arg.self_group()))
-                        )
+                        columns.append(c)
+                        values.append(self.process(c.default.arg.self_group()))
 
                         if not c.primary_key:
                             # dont add primary key column to postfetch
                             self.postfetch.append(c)
                     else:
-                        values.append(
-                            (c, self._create_crud_bind_param(c, None))
-                        )
+                        columns.append(c)
+                        values.append(self._create_crud_bind_param(c, None))
                         self.prefetch.append(c)
                 elif c.server_default is not None:
                     if not c.primary_key:
             elif self.isupdate:
                 if c.onupdate is not None and not c.onupdate.is_sequence:
                     if c.onupdate.is_clause_element:
-                        values.append(
-                            (c, self.process(c.onupdate.arg.self_group()))
-                        )
+                        columns.append(c)
+                        values.append(self.process(c.onupdate.arg.self_group()))
                         self.postfetch.append(c)
                     else:
-                        values.append(
-                            (c, self._create_crud_bind_param(c, None))
-                        )
+                        columns.append(c)
+                        values.append(self._create_crud_bind_param(c, None))
                         self.prefetch.append(c)
                 elif c.server_onupdate is not None:
                     self.postfetch.append(c)
                     (", ".join(check))
                 )
 
-        return values
+        if values:
+            values = [values]
+
+        return columns, values
 
     def visit_delete(self, delete_stmt, **kw):
         self.stack.append({'from': set([delete_stmt.table])})