Commits

F. Gabriel Gosselin committed e32fc7f

Implementing foreign key column rename
Hoisted DryRunError to generic to enable wider use for the copy_column_constraints decorator

  • Participants
  • Parent commits 7bfc69a

Comments (0)

Files changed (3)

File south/db/generic.py

     def __repr__(self):
         return 'INVALID'
 
+class DryRunError(ValueError):
+    pass
+
 class DatabaseOperations(object):
     """
     Generic SQL implementation of the DatabaseOperations.
         If columns is None, returns all constraints of the type on the table.
         """
         if self.dry_run:
-            raise ValueError("Cannot get constraints for columns during a dry run.")
+            raise DryRunError("Cannot get constraints for columns.")
 
         if columns is not None:
             columns = set(map(lambda s: s.lower(), columns))

File south/db/mysql.py

 # MySQL-specific implementations for south
 # Original author: Andrew Godwin
 # Patches by: F. Gabriel Gosselin <gabrielNOSPAM@evidens.ca>
-#             aarranz
 
 from django.db import connection
 from django.conf import settings
 from south.db import generic
+from south.db.generic import DryRunError, INVALID
 
 from south.logger import get_logger
 
         return func(self, table_name, column_name, *args, **opts)
     return _column_rm
 
+def copy_column_constraints(func):
+    """
+    Decorates column operation functions for MySQL.
+    Determines existing constraints and copies them to a new column
+    """
+    def _column_cp(self, table_name, column_old, column_new, *args, **opts):
+        try:
+            constraint = self._find_foreign_constraints(table_name, column_old)[0]
+            (rtable, rcolumn) = self._lookup_constraint_references(table_name, constraint)
+            if rtable and rcolumn:
+                fk_sql = self.foreign_key_sql(
+                            table_name, column_new, rtable, rcolumn)
+                get_logger().debug("Foreign key SQL: " + fk_sql)
+                self.add_deferred_sql(fk_sql)
+        except IndexError:
+            pass # No constraint exists so ignore
+        except DryRunError:
+            pass
+        return func(self, table_name, column_old, column_new, *args, **opts)
+    return _column_cp
+
+
 class DatabaseOperations(generic.DatabaseOperations):
     """
     MySQL implementation of database operations.
     def _is_valid_cache(self, db_name, table_name):
         cache = self._constraint_cache
         # we cache the whole db so if there are any tables table_name is valid
-        return db_name in cache and cache[db_name].get(table_name, None) is not generic.INVALID
+        return db_name in cache and cache[db_name].get(table_name, None) is not INVALID
 
     def _fill_constraint_cache(self, db_name, table_name):
         # for MySQL grab all constraints for this database.  It's just as cheap as a single column.
         cursor.execute("SET FOREIGN_KEY_CHECKS=0;")
         self.deferred_sql.append("SET FOREIGN_KEY_CHECKS=1;")
 
-    @generic.copy_column_constraints
-    @generic.delete_column_constraints
+    @copy_column_constraints
+    @delete_column_constraints
+    @generic.invalidate_table_constraints
     def rename_column(self, table_name, old, new):
         if old == new or self.dry_run:
             return []
         if not is_geom and not is_text:
             super(DatabaseOperations, self)._alter_set_defaults(field, name, params, sqls)
 
-class DryRunError(ValueError):
-    pass

File south/tests/db_mysql.py

         db.delete_column(main_table, 'foreign_id')
         constraints = db._find_foreign_constraints(main_table, 'foreign_id')
         self.assertEquals(len(constraints), 0)
+        db.delete_table(main_table)
+        db.delete_table(ref_table)
 
+    def test_rename_fk_column(self):
+        main_table = 'test_rename_foreign'
+        ref_table = 'test_rf_ref'
+        self._create_foreign_tables(main_table, ref_table)
+        db.execute_deferred_sql()
+        constraints = db._find_foreign_constraints(main_table, 'foreign_id')
+        self.assertEquals(len(constraints), 1)
+        db.rename_column(main_table, 'foreign_id', 'reference_id')
+        db.execute_deferred_sql()  #Create constraints
+        constraints = db._find_foreign_constraints(main_table, 'reference_id')
+        self.assertEquals(len(constraints), 1)
+        db.delete_table(main_table)
+        db.delete_table(ref_table)