Commits

Simon Law committed 7769eaf

Split up Migrator.migrate() completely.

Comments (0)

Files changed (3)

         return "Migration '%(migration)s' depends on unmigrated application '%(application)s'." % self.__dict__
 
 
+class FailedDryRun(SouthError):
+    def __init__(self, migration, exc_info):
+        self.migration = migration
+        self.name = migration.name()
+        self.exc_info = exc_info
+        self.traceback = ''.join(format_exception(*self.exc_info))
+
+    def __str__(self):
+        return (" ! Error found during dry run of '%(name)s'! Aborting.\n"
+                "%(traceback)s") % self.__dict__

south/migration/__init__.py

         if self.verbosity and status:
             print status
 
-    def run(self, migration):
-        # Get migration class
-        klass = migration.migration().Migration
-        # OK, we should probably do something then.
-        runfunc = getattr(klass(), self.torun)
-        args = inspect.getargspec(runfunc)
-        # Get the correct ORM.
-        if self.torun == 'forwards':
-            orm = migration.orm()
-        else:
-            orm = migration.prev_orm()
-        db.current_orm = orm
-        # If the database doesn't support running DDL inside a transaction
-        # *cough*MySQL*cough* then do a dry run first.
-        if not db.has_ddl_transactions or self.db_dry_run:
-            if not (hasattr(klass, 'no_dry_run') and klass.no_dry_run):
-                db.dry_run = True
-                db.debug, old_debug = False, db.debug
-                pending_creates = db.get_pending_creates()
-                db.start_transaction()
-                try:
-                    if len(args[0]) == 1:  # They don't want an ORM param
-                        runfunc()
-                    else:
-                        runfunc(orm)
-                        db.rollback_transactions_dry_run()
-                except:
-                    traceback.print_exc()
-                    print ' ! Error found during dry run of migration! Aborting.'
-                    return False
-                db.debug = old_debug
-                db.clear_run_data(pending_creates)
-                db.dry_run = False
-            elif db_dry_run:
-                print " - Migration '%s' is marked for no-dry-run."
-            # If they really wanted to dry-run, then quit!
-            if self.db_dry_run:
-                return
-        # Run the migration
+    def orm(self, migration):
+        raise NotImplementedError()
+
+    def backwards(self, migration):
+        return self._wrap_direction(migration.backwards(), self.orm(migration))
+
+    def direction(self, migration):
+        raise NotImplementedError()
+
+    def print_backwards(self, migration):
+        old_debug, old_dry_run = db.debug, db.dry_run
+        db.debug = db.dry_run = True
+        try:
+            self.backwards(migration)()
+        except:
+            db.debug, db.dry_run = old_debug, old_dry_run
+            raise
+
+    @staticmethod
+    def _wrap_direction(direction, orm):
+        args = inspect.getargspec(direction)
+        if len(args[0]) == 1:
+            # Old migration, no ORM should be passed in
+            return direction
+        return (lambda: direction(orm))
+
+    def dry_run_migration(self, migration):
+        if migration.no_dry_run() and self.verbosity:
+            print " - Migration '%s' is marked for no-dry-run."
+            return
+        db.dry_run = True
+        db.debug, old_debug = False, db.debug
+        pending_creates = db.get_pending_creates()
+        db.start_transaction()
+        migration_function = self.direction(migration)
+        try:
+            migration_function()
+        except:
+            raise exceptions.FailedDryRun(sys.exc_info())
+        finally:
+            db.rollback_transactions_dry_run()
+        db.debug = old_debug
+        db.clear_run_data(pending_creates)
+        db.dry_run = False
+
+    def run_migration(self, migration):
+        migration_function = self.direction(migration)
         db.start_transaction()
         try:
-            if len(args[0]) == 1:  # They don't want an ORM param
-                runfunc()
-            else:
-                runfunc(orm)
+            migration_function()
             db.execute_deferred_sql()
         except:
             db.rollback_transaction()
             if not db.has_ddl_transactions:
-                traceback.print_exc()
                 print ' ! Error found during real run of migration! Aborting.'
                 print
                 print ' ! Since you have a database that does not support running'
                 if self.torun == 'forwards':
                     print
                     print " ! You *might* be able to recover with:"
-                    db.debug = db.dry_run = True
-                    if len(args[0]) == 1:
-                        klass().backwards()
-                    else:
-                        klass().backwards(migration.prev_orm())
+                    self.print_backwards(migration)
                 print
                 print ' ! The South developers regret this has happened, and would'
                 print ' ! like to gently persuade you to consider a slightly'
         else:
             db.commit_transaction()
 
+    def run(self, migration):
+        # Get the correct ORM.
+        db.current_orm = self.orm(migration)
+        # Handle dry runs
+        if self.db_dry_run:
+            try:
+                self.dry_run_migration(migration)
+            except:
+                return False
+            return
+        # If the database doesn't support running DDL inside a transaction
+        # *cough*MySQL*cough* then do a dry run first.
+        if not db.has_ddl_transactions:
+            try:
+                self.dry_run_migration(migration)
+            except:
+                return False
+        self.run_migration(migration)
+
     def done_migrate(self, migration):
         if not self.db_dry_run:
             db.start_transaction()
             ran_migration.send(None,
                                app=migration.app_name(),
                                migration=migration,
-                               method=self.torun)
+                               method=self.__class__.__name__.lower())
 
     def migrate(self, migration):
         """
             # If this is a 'fake' migration, do nothing.
             if self.verbosity:
                 print '   (faked)'
+            result = None
         else:
-            self.run(migration)
+            result = self.run(migration)
         self.done_migrate(migration)
+        self.send_ran_migration(migration)
+        return result
 
 
 class Forwards(Migrator):
     def status(self, migration):
         return ' > %s' % migration
 
+    def forwards(self, migration):
+        return self._wrap_direction(migration.forwards(), self.orm(migration))
+
+    direction = forwards
+
+    def orm(self, migration):
+        return migration.orm()
+
     @staticmethod
     def record(migration):
         # Record us as having done this
     def status(self, migration):
         return ' < %s' % migration
 
+    def orm(self, migration):
+        return migration.prev_orm()
+
+    direction = Migrator.backwards
+
     @staticmethod
     def record(migration):
         # Record us as having not done this
     backwards = []
     if target_name == None:
         target = migrations[-1]
+        target_name = target.name()
     if target_name == "zero":
         backwards = migrations[0].backwards_plan()
     else:

south/migration/base.py

         return migration
     migration = memoize(migration)
 
+    def migration_class(self):
+        return self.migration().Migration
+
+    def migration_instance(self):
+        return self.migration_class()()
+    migration_instance = memoize(migration_instance)
+
     def previous(self):
         index = self.migrations.index(self) - 1
         if index < 0:
         result = [self.previous()]
         if result[0] is None:
             result = []
-        migclass = self.migration().Migration
         # Get forwards dependencies
-        for app, name in getattr(migclass, 'depends_on', []):
+        for app, name in getattr(self.migration_class(), 'depends_on', []):
             try:
                 migrations = Migrations(app)
             except ImproperlyConfigured:
         return self._dependents
     dependents = memoize(dependents)
 
+    def forwards(self):
+        return self.migration_instance().forwards
+
+    def backwards(self):
+        return self.migration_instance().backwards
+
     def forwards_plan(self):
         """
         Returns a list of Migration objects to be applied, in order.
         return LazyFakeORM(self.migration().Migration, self.app_name())
     orm = memoize(orm)
 
+    def no_dry_run(self):
+        migration_class = self.migration_class()
+        try:
+            return migration_class.no_dry_run
+        except AttributeError:
+            return False
+
 
 def get_app_name(app):
     """