Commits

idank committed 69bd0f8

compiler: add support for multirow inserts

Some databases support this syntax for inserts:

INSERT INTO table (col1, col2) VALUES
('v1', 'v2'),
('v3', 'v4');

which greatly increases INSERT speed.

It is now possible to pass a list of lists/tuples/dictionaries as
the values param to the Insert construct. We convert it to a flat
dictionary so we can continue using bind params. The above query
will be converted to:

INSERT INTO table (col1, col2) VALUES
(:col10, :col20),
(:col11, :col21);

Currently only supported on postgresql and mysql.

Comments (0)

Files changed (8)

lib/sqlalchemy/dialects/mysql/base.py

 
     supports_sane_rowcount = True
     supports_sane_multi_rowcount = False
+    supports_multirow_insert = True
 
     default_paramstyle = 'format'
     colspecs = colspecs

lib/sqlalchemy/dialects/postgresql/base.py

 
     supports_default_values = True
     supports_empty_insert = False
+    supports_multirow_insert = True
     default_paramstyle = 'pyformat'
     ischema_names = ischema_names
     colspecs = colspecs

lib/sqlalchemy/dialects/sqlite/base.py

     supports_default_values = True
     supports_empty_insert = False
     supports_cast = True
+    supports_multirow_insert = True
 
     default_paramstyle = 'qmark'
     execution_ctx_cls = SQLiteExecutionContext

lib/sqlalchemy/engine/default.py

     default_paramstyle = 'named'
     supports_default_values = False
     supports_empty_insert = True
+    supports_multirow_insert = False
 
     server_version_info = None
 

lib/sqlalchemy/sql/compiler.py

                                     "not support empty inserts." %
                                     self.dialect.name)
 
+        if self.multirow and not self.dialect.supports_multirow_insert:
+            raise exc.CompileError("The version of %s you are using does "
+                                    "not support multirow inserts." %
+                                    self.dialect.name)
+
         preparer = self.preparer
         supports_default_values = self.dialect.supports_default_values
 
 
         if not colparams and supports_default_values:
             text += " DEFAULT VALUES"
+        elif self.multirow:
+            values = []
+            for z in itertools.izip(*[c[1] for c in colparams]):
+                values.append('(%s)' % ', '.join(i for i in z))
+            text += " VALUES %s" % ', '.join(values)
         else:
             text += " VALUES (%s)" % \
                      ', '.join([c[1] for c in colparams])
 
         return text
 
-    def _create_crud_bind_param(self, col, value, required=False):
-        bindparam = sql.bindparam(col.key, value,
+    def _create_crud_bind_param(self, col, value, required=False, name=None):
+        if name is None:
+            name = col.key
+        bindparam = sql.bindparam(name, value,
                             type_=col.type, required=required,
                             quote=col.quote)
         bindparam._is_crud = True
         self.postfetch = []
         self.prefetch = []
         self.returning = []
+        self.multirow = False
 
         # no parameters in the statement, no parameters in the
         # compiled params - return binds for all columns
         # iterating through columns at the top to maintain ordering.
         # otherwise we might iterate through individual sets of
         # "defaults", "primary key cols", etc.
+        leftovers = []
         for c in stmt.table.columns:
             if c.key in parameters and c.key not in check_columns:
                 value = parameters.pop(c.key)
                 if sql._is_literal(value):
-                    value = self._create_crud_bind_param(
-                                    c, value, required=value is required)
+                    if isinstance(value, list):
+                        self.multirow = True
+                        # handle only the first value here -- since the
+                        # order of bindparams creation matters for positional
+                        # parameters, we must process the values as rows rather
+                        # than columns
+                        singlevalue = value.pop(0)
+                        if value:
+                            # save the rest of the values of this column
+                            # for later
+                            leftovers.append(value)
+                        value = [self._create_crud_bind_param(
+                                    c, singlevalue,
+                                    required=value is required,
+                                    name=c.key + '0')]
+                    else:
+                        value = self._create_crud_bind_param(
+                                        c, value, required=value is required)
                 elif c.primary_key and implicit_returning:
                     self.returning.append(c)
                     value = self.process(value.self_group())
                     (", ".join(check))
                 )
 
+        if leftovers:
+            for row in zip(*leftovers):
+                for i, value in enumerate(row):
+                    c, l = values[i]
+                    l.append(self._create_crud_bind_param(
+                                c, value,
+                                required=False,
+                                name=c.key + str(len(l))))
+
         return values
 
     def visit_delete(self, delete_stmt, **kw):

lib/sqlalchemy/sql/expression.py

     def _process_colparams(self, parameters):
         if isinstance(parameters, (list, tuple)):
             pp = {}
-            for i, c in enumerate(self.table.c):
-                pp[c.key] = parameters[i]
+            if isinstance(parameters[0], (list, tuple)):
+               for i, c in enumerate(self.table.c):
+                   pp[c.key] = [p[i] for p in parameters]
+            elif isinstance(parameters[0], dict):
+               for k in parameters[0].keys():
+                   pp[k] = [p[k] for p in parameters]
+            else:
+               for i, c in enumerate(self.table.c):
+                   pp[c.key] = parameters[i]
             return pp
         else:
             return parameters

test/sql/test_compiler.py

                     table.insert(inline=True),
                     "INSERT INTO sometable (foo) VALUES (foobar())", params={})
 
+    def test_multirow_insert(self):
+        data = [(1, 'a', 'b'), (2, 'a', 'b')]
+        result = "INSERT INTO mytable (myid, name, description) VALUES " \
+                 "(%(myid0)s, %(name0)s, %(description0)s), " \
+                 "(%(myid1)s, %(name1)s, %(description1)s)"
+
+        stmt = insert(table1, data, dialect='postgresql')
+        self.assert_compile(stmt, result, dialect=postgresql.dialect())
+
+        stmt = table1.insert(values=data, dialect='postgresql')
+        self.assert_compile(stmt, result, dialect=postgresql.dialect())
+
+        stmt = table1.insert(dialect='postgresql').values(data)
+        self.assert_compile(stmt, result, dialect=postgresql.dialect())
+
     def test_update(self):
         self.assert_compile(
                 update(table1, table1.c.myid == 7),

test/sql/test_query.py

     def teardown_class(cls):
         metadata.drop_all()
 
+    def test_multirow_insert(self):
+        users.insert(values=[{'user_id':7, 'user_name':'jack'},
+            {'user_id':8, 'user_name':'ed'}]).execute()
+        rows = users.select().execute().fetchall()
+        self.assert_(rows[0] == (7, 'jack'))
+        self.assert_(rows[1] == (8, 'ed'))
+        users.insert(values=[(9, 'jack'), (10, 'ed')]).execute()
+        rows = users.select().execute().fetchall()
+        self.assert_(rows[2] == (9, 'jack'))
+        self.assert_(rows[3] == (10, 'ed'))
+
     def test_insert_heterogeneous_params(self):
         """test that executemany parameters are asserted to match the
         parameter set of the first."""