Commits

Andrew Godwin  committed d310fe2 Merge

Merged in beyondwords/south-mysql (pull request #21)

  • Participants
  • Parent commits 1924f66, 577f264

Comments (0)

Files changed (4)

File south/db/generic.py

     def __repr__(self):
         return 'INVALID'
 
+class DryRunError(ValueError):
+    pass
+
 class DatabaseOperations(object):
-
     """
     Generic SQL implementation of the DatabaseOperations.
     Some of this code comes from Django Evolution.
     create_primary_key_string = "ALTER TABLE %(table)s ADD CONSTRAINT %(constraint)s PRIMARY KEY (%(columns)s)"
     delete_primary_key_sql = "ALTER TABLE %(table)s DROP CONSTRAINT %(constraint)s"
     add_check_constraint_fragment = "ADD CONSTRAINT %(constraint)s CHECK (%(check)s)"
+    rename_table_sql = "ALTER TABLE %s RENAME TO %s;"
     backend_name = None
     default_schema_name = "public"
 
         if self.debug:
             print "   = %s" % sql, params
 
-        get_logger().debug('south execute "%s" with params "%s"' % (sql, params))
-
         if self.dry_run:
             return []
 
+        get_logger().debug('execute "%s" with params "%s"' % (sql, params))
+
         try:
             cursor.execute(sql, params)
         except DatabaseError, e:
             # Short-circuit out.
             return
         params = (self.quote_name(old_table_name), self.quote_name(table_name))
-        self.execute('ALTER TABLE %s RENAME TO %s;' % params)
+        self.execute(self.rename_table_sql % params)
+        # Invalidate the not-yet-indexed table
+        self._set_cache(table_name, value=INVALID)
 
 
     @invalidate_table_constraints
         If columns is None, returns all constraints of the type on the table.
         """
         if self.dry_run:
-            raise ValueError("Cannot get constraints for columns during a dry run.")
+            raise DryRunError("Cannot get constraints for columns.")
 
         if columns is not None:
             columns = set(map(lambda s: s.lower(), columns))
             if self.debug:
                 print '   - no dry run output for delete_foreign_key() due to dynamic DDL, sorry'
             return # We can't look at the DB to get the constraints
-        constraints = list(self._constraints_affecting_columns(table_name, [column], "FOREIGN KEY"))
+        constraints = self._find_foreign_constraints(table_name, column)
         if not constraints:
             raise ValueError("Cannot find a FOREIGN KEY constraint on table %s, column %s" % (table_name, column))
         for constraint_name in constraints:
 
     drop_foreign_key = alias('delete_foreign_key')
 
+    def _find_foreign_constraints(self, table_name, column_name=None):
+        return list(self._constraints_affecting_columns(
+                    table_name, [column_name], "FOREIGN KEY"))
 
     def create_index_name(self, table_name, column_names, suffix=""):
         """

File south/db/mysql.py

+# MySQL-specific implementations for south
+# Original author: Andrew Godwin
+# Patches by: F. Gabriel Gosselin <gabrielNOSPAM@evidens.ca>
 
 from django.db import connection
 from django.conf import settings
 from south.db import generic
+from south.db.generic import DryRunError, INVALID
+
+from south.logger import get_logger
+
+def delete_column_constraints(func):
+    """
+    Decorates column operation functions for MySQL.
+    Deletes the constraints from the database and clears local cache.
+    """
+    def _column_rm(self, table_name, column_name, *args, **opts):
+        # Delete foreign key constraints
+        try:
+            self.delete_foreign_key(table_name, column_name)
+        except ValueError:
+            pass # If no foreign key on column, OK because it checks first
+        # Delete constraints referring to this column
+        try:
+            reverse = self._lookup_reverse_constraint(table_name, column_name)
+            for cname, rtable, rcolumn in reverse:
+                self.delete_foreign_key(rtable, rcolumn)
+        except DryRunError:
+            pass
+        return func(self, table_name, column_name, *args, **opts)
+    return _column_rm
+
+def copy_column_constraints(func):
+    """
+    Decorates column operation functions for MySQL.
+    Determines existing constraints and copies them to a new column
+    """
+    def _column_cp(self, table_name, column_old, column_new, *args, **opts):
+        # Copy foreign key constraint
+        try:
+            constraint = self._find_foreign_constraints(table_name, column_old)[0]
+            (ftable, fcolumn) = self._lookup_constraint_references(table_name, constraint)
+            if ftable and fcolumn:
+                fk_sql = self.foreign_key_sql(
+                            table_name, column_new, ftable, fcolumn)
+                get_logger().debug("Foreign key SQL: " + fk_sql)
+                self.add_deferred_sql(fk_sql)
+        except IndexError:
+            pass # No constraint exists so ignore
+        except DryRunError:
+            pass
+        # Copy constraints referring to this column
+        try:
+            reverse = self._lookup_reverse_constraint(table_name, column_old)
+            for cname, rtable, rcolumn in reverse:
+                fk_sql = self.foreign_key_sql(
+                        rtable, rcolumn, table_name, column_new)
+                self.add_deferred_sql(fk_sql)
+        except DryRunError:
+            pass
+        return func(self, table_name, column_old, column_new, *args, **opts)
+    return _column_cp
+
+def invalidate_table_constraints(func):
+    """
+    For MySQL we grab all table constraints simultaneously, so this is
+    effective.
+    It further solves the issues of invalidating referred table constraints.
+    """
+    def _cache_clear(self, table, *args, **opts):
+        db_name = self._get_setting('NAME')
+        if db_name in self._constraint_cache:
+            del self._constraint_cache[db_name]
+        if db_name in self._reverse_cache:
+            del self._reverse_cache[db_name]
+        if db_name in self._constraint_references:
+            del self._constraint_references[db_name]
+        return func(self, table, *args, **opts)
+    return _cache_clear
 
 class DatabaseOperations(generic.DatabaseOperations):
-
     """
     MySQL implementation of database operations.
-    
-    MySQL is an 'interesting' database; it has no DDL transaction support,
-    among other things. This can confuse people when they ask how they can
-    roll back - hence the dry runs, etc., found in the migration code.
-    Alex agrees, and Alex is always right.
-    [19:06] <Alex_Gaynor> Also, I want to restate once again that MySQL is a special database
-    
-    (Still, if you want a key-value store with relational tendancies, go MySQL!)
+
+    MySQL has no DDL transaction support This can confuse people when they ask
+    how to roll back - hence the dry runs, etc., found in the migration code.
     """
-    
+
     backend_name = "mysql"
     alter_string_set_type = ''
     alter_string_set_null = 'MODIFY %(column)s %(type)s NULL;'
     has_ddl_transactions = False
     has_check_constraints = False
     delete_unique_sql = "ALTER TABLE %s DROP INDEX %s"
+    rename_table_sql = "RENAME TABLE %s TO %s;"
 
     geom_types = ['geometry', 'point', 'linestring', 'polygon']
     text_types = ['text', 'blob',]
 
+    def __init__(self, db_alias):
+        self._constraint_references = {}
+        self._reverse_cache = {}
+        super(DatabaseOperations, self).__init__(db_alias)
+
     def _is_valid_cache(self, db_name, table_name):
         cache = self._constraint_cache
         # we cache the whole db so if there are any tables table_name is valid
-        return db_name in cache and cache[db_name].get(table_name, None) is not generic.INVALID
+        return db_name in cache and cache[db_name].get(table_name, None) is not INVALID
 
     def _fill_constraint_cache(self, db_name, table_name):
         # for MySQL grab all constraints for this database.  It's just as cheap as a single column.
         self._constraint_cache[db_name] = {}
         self._constraint_cache[db_name][table_name] = {}
+        self._reverse_cache[db_name] = {}
+        self._constraint_references[db_name] = {}
 
         name_query = """
-            SELECT kc.constraint_name, kc.column_name, kc.table_name
+            SELECT kc.`constraint_name`, kc.`column_name`, kc.`table_name`,
+                kc.`referenced_table_name`, kc.`referenced_column_name`
             FROM information_schema.key_column_usage AS kc
             WHERE
                 kc.table_schema = %s
         if not rows:
             return
         cnames = {}
-        for constraint, column, table in rows:
+        for constraint, column, table, ref_table, ref_column in rows:
             key = (table, constraint)
             cnames.setdefault(key, set())
-            cnames[key].add(column)
+            cnames[key].add((column, ref_table, ref_column))
 
         type_query = """
             SELECT c.constraint_name, c.table_name, c.constraint_type
                 cols = cnames[key]
             except KeyError:
                 cols = set()
-            for column in cols:
+            for column_set in cols:
+                (column, ref_table, ref_column) = column_set
                 self._constraint_cache[db_name][table].setdefault(column, set())
-                self._constraint_cache[db_name][table][column].add((kind, constraint))
-
+                if kind == 'FOREIGN KEY':
+                    self._constraint_cache[db_name][table][column].add((kind,
+                        constraint))
+                    # Create constraint lookup, see constraint_references
+                    self._constraint_references[db_name][(table,
+                        constraint)] = (ref_table, ref_column)
+                    # Create reverse table lookup, reverse_lookup
+                    self._reverse_cache[db_name].setdefault(ref_table, {})
+                    self._reverse_cache[db_name][ref_table].setdefault(ref_column,
+                            set())
+                    self._reverse_cache[db_name][ref_table][ref_column].add(
+                            (constraint, table, column))
+                else:
+                    self._constraint_cache[db_name][table][column].add((kind,
+                    constraint))
 
     def connection_init(self):
         """
         cursor.execute("SET FOREIGN_KEY_CHECKS=0;")
         self.deferred_sql.append("SET FOREIGN_KEY_CHECKS=1;")
 
-    @generic.copy_column_constraints
-    @generic.delete_column_constraints
+    @copy_column_constraints
+    @delete_column_constraints
+    @invalidate_table_constraints
     def rename_column(self, table_name, old, new):
         if old == new or self.dry_run:
             return []
-        
+
         rows = [x for x in self.execute('DESCRIBE %s' % (self.quote_name(table_name),)) if x[0] == old]
-        
+
         if not rows:
             raise ValueError("No column '%s' in '%s'." % (old, table_name))
-        
+
         params = (
             self.quote_name(table_name),
             self.quote_name(old),
             rows[0][4] and "%s" or "",
             rows[0][5] or "",
         )
-        
+
         sql = 'ALTER TABLE %s CHANGE COLUMN %s %s %s %s %s %s %s;' % params
-        
+
         if rows[0][4]:
             self.execute(sql, (rows[0][4],))
         else:
             self.execute(sql)
 
+    @delete_column_constraints
     def delete_column(self, table_name, name):
-        db_name = self._get_setting('NAME')
-
-        # See if there is a foreign key on this column
-        result = 0
-        for kind, cname in self.lookup_constraint(db_name, table_name, name):
-            if kind == 'FOREIGN KEY':
-                result += 1
-                fkey_name = cname
-        if result:
-            assert result == 1 # We should only have one result, otherwise there's Issues
-            cursor = self._get_connection().cursor()
-            drop_query = "ALTER TABLE %s DROP FOREIGN KEY %s"
-            self.execute(drop_query % (self.quote_name(table_name), self.quote_name(fkey_name)))
-
         super(DatabaseOperations, self).delete_column(table_name, name)
 
-    @generic.invalidate_table_constraints
+    @invalidate_table_constraints
     def rename_table(self, old_table_name, table_name):
+        super(DatabaseOperations, self).rename_table(old_table_name,
+                table_name)
+
+    @invalidate_table_constraints
+    def delete_table(self, table_name):
+        super(DatabaseOperations, self).delete_table(table_name)
+
+    def _lookup_constraint_references(self, table_name, cname):
         """
-        Renames the table 'old_table_name' to 'table_name'.
+        Provided an existing table and constraint, returns tuple of (foreign
+        table, column)
         """
-        if old_table_name == table_name:
-            # No Operation
-            return
-        params = (self.quote_name(old_table_name), self.quote_name(table_name))
-        self.execute('RENAME TABLE %s TO %s;' % params)
+        db_name = self._get_setting('NAME')
+        try:
+            return self._constraint_references[db_name][(table_name, cname)]
+        except KeyError:
+            return None
+
+    def _lookup_reverse_constraint(self, table_name, column_name=None):
+        """Look for the column referenced by a foreign constraint"""
+        db_name = self._get_setting('NAME')
+        if self.dry_run:
+            raise DryRunError("Cannot get constraints for columns.")
+
+        if not self._is_valid_cache(db_name, table_name):
+            # Piggy-back on lookup_constraint, ensures cache exists
+            self.lookup_constraint(db_name, table_name)
+
+        try:
+            table = self._reverse_cache[db_name][table_name]
+            if column_name == None:
+                return [(y, tuple(y)) for x, y in table.items()]
+            else:
+                return tuple(table[column_name])
+        except KeyError, e:
+            return []
 
     def _field_sanity(self, field):
         """
         if is_geom or is_text:
             field._suppress_default = True
         return field
-    
-    
+
     def _alter_set_defaults(self, field, name, params, sqls):
         """
         MySQL does not support defaults on text or blob columns.
         is_text = True in [ type.find(t) > -1 for t in self.text_types ]
         if not is_geom and not is_text:
             super(DatabaseOperations, self)._alter_set_defaults(field, name, params, sqls)
+

File south/tests/__init__.py

 
 if not skiptest:
     from south.tests.db import *
+    from south.tests.db_mysql import *
     from south.tests.logic import *
     from south.tests.autodetection import *
     from south.tests.logger import *

File south/tests/db_mysql.py

+# Additional MySQL-specific tests
+# Written by: F. Gabriel Gosselin <gabrielNOSPAM@evidens.ca>
+# Based on tests by: aarranz
+import unittest
+
+from south.db import db, generic, mysql
+from django.db import connection, models
+
+
+class TestMySQLOperations(unittest.TestCase):
+    """MySQL-specific tests"""
+    def setUp(self):
+        db.debug = False
+        db.clear_deferred_sql()
+
+    def tearDown(self):
+        pass
+
+    def _create_foreign_tables(self, main_name, reference_name):
+        # Create foreign table and model
+        Foreign = db.mock_model(model_name='Foreign', db_table=reference_name,
+                                db_tablespace='', pk_field_name='id',
+                                pk_field_type=models.AutoField,
+                                pk_field_args=[])
+        db.create_table(reference_name, [
+                ('id', models.AutoField(primary_key=True)),
+            ])
+        # Create table with foreign key
+        db.create_table(main_name, [
+                ('id', models.AutoField(primary_key=True)),
+                ('foreign', models.ForeignKey(Foreign)),
+            ])
+        return Foreign
+
+    def test_constraint_references(self):
+        """Tests that referred table is reported accurately"""
+        main_table = 'test_cns_ref'
+        reference_table = 'test_cr_foreign'
+        db.start_transaction()
+        self._create_foreign_tables(main_table, reference_table)
+        db.execute_deferred_sql()
+        constraint = db._find_foreign_constraints(main_table, 'foreign_id')[0]
+        constraint_name = 'foreign_id_refs_id_%x' % (abs(hash((main_table,
+            reference_table))))
+        self.assertEquals(constraint_name, constraint)
+        references = db._lookup_constraint_references(main_table, constraint)
+        self.assertEquals((reference_table, 'id'), references)
+        db.delete_table(main_table)
+        db.delete_table(reference_table)
+
+    def test_reverse_column_constraint(self):
+        """Tests that referred column in a foreign key (ex. id) is found"""
+        main_table = 'test_reverse_ref'
+        reference_table = 'test_rr_foreign'
+        db.start_transaction()
+        self._create_foreign_tables(main_table, reference_table)
+        db.execute_deferred_sql()
+        inverse = db._lookup_reverse_constraint(reference_table, 'id')
+        (cname, rev_table, rev_column) = inverse[0]
+        self.assertEquals(main_table, rev_table)
+        self.assertEquals('foreign_id', rev_column)
+        db.delete_table(main_table)
+        db.delete_table(reference_table)
+
+    def test_delete_fk_column(self):
+        main_table = 'test_drop_foreign'
+        ref_table = 'test_df_ref'
+        self._create_foreign_tables(main_table, ref_table)
+        db.execute_deferred_sql()
+        constraints = db._find_foreign_constraints(main_table, 'foreign_id')
+        self.assertEquals(len(constraints), 1)
+        db.delete_column(main_table, 'foreign_id')
+        constraints = db._find_foreign_constraints(main_table, 'foreign_id')
+        self.assertEquals(len(constraints), 0)
+        db.delete_table(main_table)
+        db.delete_table(ref_table)
+
+    def test_rename_fk_column(self):
+        main_table = 'test_rename_foreign'
+        ref_table = 'test_rf_ref'
+        self._create_foreign_tables(main_table, ref_table)
+        db.execute_deferred_sql()
+        constraints = db._find_foreign_constraints(main_table, 'foreign_id')
+        self.assertEquals(len(constraints), 1)
+        db.rename_column(main_table, 'foreign_id', 'reference_id')
+        db.execute_deferred_sql()  #Create constraints
+        constraints = db._find_foreign_constraints(main_table, 'reference_id')
+        self.assertEquals(len(constraints), 1)
+        db.delete_table(main_table)
+        db.delete_table(ref_table)
+
+    def test_rename_fk_inbound(self):
+        """
+        Tests that the column referred to by an external column can be renamed.
+        Edge case, but also useful as stepping stone to renaming tables.
+        """
+        main_table = 'test_rename_fk_inbound'
+        ref_table = 'test_rfi_ref'
+        self._create_foreign_tables(main_table, ref_table)
+        db.execute_deferred_sql()
+        constraints = db._lookup_reverse_constraint(ref_table, 'id')
+        self.assertEquals(len(constraints), 1)
+        db.rename_column(ref_table, 'id', 'rfi_id')
+        db.execute_deferred_sql()  #Create constraints
+        constraints = db._lookup_reverse_constraint(ref_table, 'rfi_id')
+        self.assertEquals(len(constraints), 1)
+        cname = db._find_foreign_constraints(main_table, 'foreign_id')[0]
+        (rtable, rcolumn) = db._lookup_constraint_references(main_table, cname)
+        self.assertEquals(rcolumn, 'rfi_id')
+        db.delete_table(main_table)
+        db.delete_table(ref_table)
+
+    def test_rename_constrained_table(self):
+        """Renames a table with a foreign key column (towards another table)"""
+        main_table = 'test_rn_table'
+        ref_table = 'test_rt_ref'
+        renamed_table = 'test_renamed_table'
+        self._create_foreign_tables(main_table, ref_table)
+        db.execute_deferred_sql()
+        constraints = db._find_foreign_constraints(main_table, 'foreign_id')
+        self.assertEquals(len(constraints), 1)
+        db.rename_table(main_table, renamed_table)
+        db.execute_deferred_sql()  #Create constraints
+        constraints = db._find_foreign_constraints(renamed_table, 'foreign_id')
+        self.assertEquals(len(constraints), 1)
+        (rtable, rcolumn) = db._lookup_constraint_references(
+                renamed_table, constraints[0])
+        self.assertEquals(rcolumn, 'id')
+        db.delete_table(renamed_table)
+        db.delete_table(ref_table)
+
+    def test_renamed_referenced_table(self):
+        """Rename a table referred to in a foreign key"""
+        main_table = 'test_rn_refd_table'
+        ref_table = 'test_rrt_ref'
+        renamed_table = 'test_renamed_ref'
+        self._create_foreign_tables(main_table, ref_table)
+        db.execute_deferred_sql()
+        constraints = db._lookup_reverse_constraint(ref_table)
+        self.assertEquals(len(constraints), 1)
+        db.rename_table(ref_table, renamed_table)
+        db.execute_deferred_sql()  #Create constraints
+        constraints = db._find_foreign_constraints(main_table, 'foreign_id')
+        self.assertEquals(len(constraints), 1)
+        (rtable, rcolumn) = db._lookup_constraint_references(
+                main_table, constraints[0])
+        self.assertEquals(renamed_table, rtable)
+        db.delete_table(main_table)
+        db.delete_table(renamed_table)
+