Commits

Andrew Godwin committed 4f86714

Fix #183 (Changing FKs doesn't take care of the constraints)

Comments (0)

Files changed (5)

south/db/generic.py

     allows_combined_alters = True
     add_column_string = 'ALTER TABLE %s ADD COLUMN %s;'
     delete_unique_sql = "ALTER TABLE %s DROP CONSTRAINT %s"
-    delete_foreign_key_sql = 'ALTER TABLE %s DROP CONSTRAINT %s'
+    delete_foreign_key_sql = 'ALTER TABLE %(table)s DROP CONSTRAINT %(constraint)s'
     supports_foreign_keys = True
     max_index_name_length = 63
     drop_index_string = 'DROP INDEX %(index_name)s'
         To be overriden by backend specific subclasses
         @param field: The field to generate type for
         """
-        return field.db_type()
+        try:
+            return field.db_type(connection=self._get_connection())
+        except TypeError:
+            return field.db_type()
 
     def alter_column(self, table_name, name, field, explicit_name=True):
         """
                     'table': self.quote_name(table_name),
                     'constraint': self.quote_name(constraint),
                 })
+        
+        # Drop all foreign key constraints
+        try:
+            self.delete_foreign_key(table_name, name)
+        except ValueError:
+            # There weren't any
+            pass
 
         # First, change the type
         params = {
 
 
         # Next, nullity
-        params = {
-            "column": self.quote_name(name),
-            "type": field.db_type(),
-        }
         if field.null:
             sqls.append((self.alter_string_set_null % params, []))
         else:
             sqls.append((self.alter_string_drop_null % params, []))
 
-        # TODO: Unique
-
+        # Finally, actually change the column
         if self.allows_combined_alters:
             sqls, values = zip(*sqls)
             self.execute(
             # Databases like e.g. MySQL don't like more than one alter at once.
             for sql, values in sqls:
                 self.execute("ALTER TABLE %s %s;" % (self.quote_name(table_name), sql), values)
+        
+        # Add back FK constraints if needed
+        if field.rel and self.supports_foreign_keys:
+            self.execute(
+                self.foreign_key_sql(
+                    table_name,
+                    field.column,
+                    field.rel.to._meta.db_table,
+                    field.rel.to._meta.get_field(field.rel.field_name).column
+                )
+            )
 
 
     def _constraints_affecting_columns(self, table_name, columns, type="UNIQUE"):
         # Possible hook to fiddle with the fields (e.g. defaults & TEXT on MySQL)
         field = self._field_sanity(field)
 
-        sql = field.db_type()
+        try:
+            sql = field.db_type(connection=self._get_connection())
+        except TypeError:
+            sql = field.db_type()
+        
         if sql:        
             field_output = [self.quote_name(field.column), sql]
             field_output.append('%sNULL' % (not field.null and 'NOT ' or ''))
             self.quote_name(to_column_name),
             self._get_connection().ops.deferrable_sql() # Django knows this
         )
-
+    
 
     def delete_foreign_key(self, table_name, column):
         "Drop a foreign key constraint"
         if not constraints:
             raise ValueError("Cannot find a FOREIGN KEY constraint on table %s, column %s" % (table_name, column))
         for constraint_name in constraints:
-            self.execute(self.delete_foreign_key_sql % (
-                self.quote_name(table_name),
-                self.quote_name(constraint_name),
-            ))
+            self.execute(self.delete_foreign_key_sql % {
+                "table": self.quote_name(table_name),
+                "constraint": self.quote_name(constraint_name),
+            })
 
     drop_foreign_key = alias('delete_foreign_key')
 

south/db/mysql.py

     alter_string_drop_null = 'MODIFY %(column)s %(type)s NOT NULL;'
     drop_index_string = 'DROP INDEX %(index_name)s ON %(table_name)s'
     delete_primary_key_sql = "ALTER TABLE %(table)s DROP PRIMARY KEY"
+    delete_foreign_key_sql = "ALTER TABLE %(table)s DROP FOREIGN KEY %(constraint)s"
     allows_combined_alters = False
     has_ddl_transactions = False
     has_check_constraints = False
         """
         This particular override stops us sending DEFAULTs for BLOB/TEXT columns.
         """
-        if field.db_type().upper() in ["BLOB", "TEXT", "LONGTEXT"]:
+        if self._db_type_for_alter_column(field).upper() in ["BLOB", "TEXT", "LONGTEXT"]:
             field._suppress_default = True
         return field

south/db/postgresql_psycopg2.py

         Strips CHECKs from PositiveSmallIntegerField) and PositiveIntegerField
         @param field: The field to generate type for
         """
+        super_result = super(DatabaseOperations, self)._db_type_for_alter_column(field)
         if isinstance(field, models.PositiveSmallIntegerField) or isinstance(field, models.PositiveIntegerField):
-            return field.db_type().split(" ")[0]
-        return super(DatabaseOperations, self)._db_type_for_alter_column(field)
+            return super_result.split(" ")[0]
+        return super_result

south/migration/__init__.py

     app_label = migrations.app_label()
 
     verbosity = int(verbosity)
-    db.debug = (verbosity > 1)
     # Fire off the pre-migrate signal
     pre_migrate.send(None, app=app_label)
     
         # We now have to make sure the migrations are all reloaded, as they'll
         # have imported the old value of south.db.db.
         Migrations.invalidate_all_modules()
+    
+    south.db.db.debug = (verbosity > 1)
     applied = check_migration_histories(applied, delete_ghosts)
     
     # Guess the target_name
     migrator = get_migrator(direction, db_dry_run, fake, load_initial_data)
     if migrator:
         migrator.print_title(target)
-        success = migrator.migrate_many(target, workplan)
+        success = migrator.migrate_many(target, workplan, database)
         # Finally, fire off the post-migrate signal
         if success:
             post_migrate.send(None, app=app_label)

south/migration/migrators.py

 from django.core.management.commands import loaddata
 from django.db import models
 
+import south.db
 from south import exceptions
-from south.db import db
+from south.db import DEFAULT_DB_ALIAS
 from south.models import MigrationHistory
 from south.signals import ran_migration
 
         return (lambda: direction(orm))
 
     @staticmethod
-    def record(migration):
+    def record(migration, database):
         raise NotImplementedError()
 
     def run_migration_error(self, migration, extra_info=''):
 
     def run_migration(self, migration):
         migration_function = self.direction(migration)
-        db.start_transaction()
+        south.db.db.start_transaction()
         try:
             migration_function()
-            db.execute_deferred_sql()
+            south.db.db.execute_deferred_sql()
         except:
-            db.rollback_transaction()
-            if not db.has_ddl_transactions:
+            south.db.db.rollback_transaction()
+            if not south.db.db.has_ddl_transactions:
                 print self.run_migration_error(migration)
             raise
         else:
-            db.commit_transaction()
+            south.db.db.commit_transaction()
 
     def run(self, migration):
         # Get the correct ORM.
-        db.current_orm = self.orm(migration)
+        south.db.db.current_orm = self.orm(migration)
         # If the database doesn't support running DDL inside a transaction
         # *cough*MySQL*cough* then do a dry run first.
-        if not db.has_ddl_transactions:
+        if not south.db.db.has_ddl_transactions:
             dry_run = DryRunMigrator(migrator=self, ignore_fail=False)
             dry_run.run_migration(migration)
         return self.run_migration(migration)
 
-    def done_migrate(self, migration):
-        db.start_transaction()
+    def done_migrate(self, migration, database):
+        south.db.db.start_transaction()
         try:
             # Record us as having done this
-            self.record(migration)
+            self.record(migration, database)
         except:
-            db.rollback_transaction()
+            south.db.db.rollback_transaction()
             raise
         else:
-            db.commit_transaction()
+            south.db.db.commit_transaction()
 
     def send_ran_migration(self, migration):
         ran_migration.send(None,
                            migration=migration,
                            method=self.__class__.__name__.lower())
 
-    def migrate(self, migration):
+    def migrate(self, migration, database):
         """
         Runs the specified migration forwards/backwards, in order.
         """
         migration_name = migration.name()
         self.print_status(migration)
         result = self.run(migration)
-        self.done_migrate(migration)
+        self.done_migrate(migration, database)
         self.send_ran_migration(migration)
         return result
 
-    def migrate_many(self, target, migrations):
+    def migrate_many(self, target, migrations, database):
         raise NotImplementedError()
 
 
         if migration.no_dry_run() and self.verbosity:
             print " - Migration '%s' is marked for no-dry-run." % migration
             return
-        db.dry_run = True
+        south.db.db.dry_run = True
         if self._ignore_fail:
-            db.debug, old_debug = False, db.debug
-        pending_creates = db.get_pending_creates()
-        db.start_transaction()
+            south.db.db.debug, old_debug = False, south.db.db.debug
+        pending_creates = south.db.db.get_pending_creates()
+        south.db.db.start_transaction()
         migration_function = self.direction(migration)
         try:
             try:
                 migration_function()
-                db.execute_deferred_sql()
+                south.db.db.execute_deferred_sql()
             except:
                 raise exceptions.FailedDryRun(migration, sys.exc_info())
         finally:
-            db.rollback_transactions_dry_run()
+            south.db.db.rollback_transactions_dry_run()
             if self._ignore_fail:
-                db.debug = old_debug
-            db.clear_run_data(pending_creates)
-            db.dry_run = False
+                south.db.db.debug = old_debug
+            south.db.db.clear_run_data(pending_creates)
+            south.db.db.dry_run = False
 
     def run_migration(self, migration):
         try:
             models.get_apps = old_get_apps
             loaddata.get_apps = old_get_apps
 
-    def migrate_many(self, target, migrations):
+    def migrate_many(self, target, migrations, database):
         migrator = self._migrator
-        result = migrator.__class__.migrate_many(migrator, target, migrations)
+        result = migrator.__class__.migrate_many(migrator, target, migrations, database)
         if result:
             self.load_initial_data(target)
         return True
     direction = forwards
 
     @staticmethod
-    def record(migration):
+    def record(migration, database):
         # Record us as having done this
         record = MigrationHistory.for_migration(migration)
         record.applied = datetime.datetime.utcnow()
-        record.save()
+        if database != DEFAULT_DB_ALIAS:
+            record.save(using=database)
+        else:
+            # Django 1.1 and below always go down this branch.
+            record.save()
 
     def format_backwards(self, migration):
-        old_debug, old_dry_run = db.debug, db.dry_run
-        db.debug = db.dry_run = True
+        old_debug, old_dry_run = south.db.db.debug, south.db.db.dry_run
+        south.db.db.debug = south.db.db.dry_run = True
         stdout = sys.stdout
         sys.stdout = StringIO()
         try:
             except:
                 raise
         finally:
-            db.debug, db.dry_run = old_debug, old_dry_run
+            south.db.db.debug, south.db.db.dry_run = old_debug, old_dry_run
             sys.stdout = stdout
 
     def run_migration_error(self, migration, extra_info=''):
                       (self.format_backwards(migration), extra_info))
         return super(Forwards, self).run_migration_error(migration, extra_info)
 
-    def migrate_many(self, target, migrations):
+    def migrate_many(self, target, migrations, database):
         try:
             for migration in migrations:
-                result = self.migrate(migration)
+                result = self.migrate(migration, database)
                 if result is False: # The migrations errored, but nicely.
                     return False
         finally:
             # Call any pending post_syncdb signals
-            db.send_pending_create_signals()
+            south.db.db.send_pending_create_signals()
         return True
 
 
     direction = Migrator.backwards
 
     @staticmethod
-    def record(migration):
+    def record(migration, database):
         # Record us as having not done this
         record = MigrationHistory.for_migration(migration)
         if record.id is not None:
-            record.delete()
+            if database != DEFAULT_DB_ALIAS:
+                record.delete(using=database)
+            else:
+                # Django 1.1 always goes down here
+                record.delete()
 
-    def migrate_many(self, target, migrations):
+    def migrate_many(self, target, migrations, database):
         for migration in migrations:
-            self.migrate(migration)
+            self.migrate(migration, database)
         return True