jason kirtland avatar jason kirtland committed cc92707

Allow auto_increment on any pk column, not just the first.

Comments (0)

Files changed (2)

lib/sqlalchemy/databases/mysql.py

         if not column.nullable:
             colspec.append('NOT NULL')
 
-        # FIXME: #649 ASAP
-        if column.primary_key:
-            if (len(column.foreign_keys)==0
-                and first_pk
-                and column.autoincrement
-                and isinstance(column.type, sqltypes.Integer)):
-                colspec.append('AUTO_INCREMENT')
+        if column.primary_key and column.autoincrement:
+            try:
+                first = [c for c in column.table.primary_key.columns
+                         if (c.autoincrement and
+                             isinstance(c.type, sqltypes.Integer) and
+                             not c.foreign_keys)].pop(0)
+                if column is first:
+                    colspec.append('AUTO_INCREMENT')
+            except IndexError:
+                pass
 
         return ' '.join(colspec)
 
         # AUTO_INCREMENT
         if spec.get('autoincr', False):
             col_kw['autoincrement'] = True
+        elif issubclass(col_type, sqltypes.Integer):
+            col_kw['autoincrement'] = False
 
         # DEFAULT
         default = spec.get('default', None)

test/dialect/mysql.py

 
         m.drop_all()
 
+    @testing.supported('mysql')
+    def test_autoincrement(self):
+        meta = MetaData(testbase.db)
+        try:
+            Table('ai_1', meta,
+                  Column('int_y', Integer, primary_key=True),
+                  Column('int_n', Integer, PassiveDefault('0'),
+                         primary_key=True))
+            Table('ai_2', meta,
+                  Column('int_y', Integer, primary_key=True),
+                  Column('int_n', Integer, PassiveDefault('0'),
+                         primary_key=True))
+            Table('ai_3', meta,
+                  Column('int_n', Integer, PassiveDefault('0'),
+                         primary_key=True, autoincrement=False),
+                  Column('int_y', Integer, primary_key=True))
+            Table('ai_4', meta,
+                  Column('int_n', Integer, PassiveDefault('0'),
+                         primary_key=True, autoincrement=False),
+                  Column('int_n2', Integer, PassiveDefault('0'),
+                         primary_key=True, autoincrement=False))
+            Table('ai_5', meta,
+                  Column('int_y', Integer, primary_key=True),
+                  Column('int_n', Integer, PassiveDefault('0'),
+                         primary_key=True, autoincrement=False))
+            Table('ai_6', meta,
+                  Column('o1', String(1), PassiveDefault('x'),
+                         primary_key=True),
+                  Column('int_y', Integer, primary_key=True))
+            Table('ai_7', meta,
+                  Column('o1', String(1), PassiveDefault('x'),
+                         primary_key=True),
+                  Column('o2', String(1), PassiveDefault('x'),
+                         primary_key=True),
+                  Column('int_y', Integer, primary_key=True))
+            Table('ai_8', meta,
+                  Column('o1', String(1), PassiveDefault('x'),
+                         primary_key=True),
+                  Column('o2', String(1), PassiveDefault('x'),
+                         primary_key=True))
+            meta.create_all()
+
+            table_names = ['ai_1', 'ai_2', 'ai_3', 'ai_4',
+                           'ai_5', 'ai_6', 'ai_7', 'ai_8']
+            mr = MetaData(testbase.db)
+            mr.reflect(only=table_names)
+
+            for tbl in [mr.tables[name] for name in table_names]:
+                for c in tbl.c:
+                    if c.name.startswith('int_y'):
+                        assert c.autoincrement
+                    elif c.name.startswith('int_n'):
+                        assert not c.autoincrement
+                tbl.insert().execute()
+                if 'int_y' in tbl.c:
+                    assert select([tbl.c.int_y]).scalar() == 1
+                    assert list(tbl.select().execute().fetchone()).count(1) == 1
+                else:
+                    assert 1 not in list(tbl.select().execute().fetchone())
+        finally:
+            meta.drop_all()
+
     def assert_eq(self, got, wanted):
         if got != wanted:
             print "Expected %s" % wanted
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.