Commits

Simon Law committed 6d48cfc

Split out get_migrator() and LoadInitialDataMigrator from migrate_app()

Comments (0)

Files changed (4)

south/management/commands/migrate.py

                     fake = fake,
                     db_dry_run = db_dry_run,
                     verbosity = int(options.get('verbosity', 0)),
-                    load_inital_data = not options.get('no_initial_data', False),
+                    load_initial_data = not options.get('no_initial_data', False),
                     skip = skip,
                 )
                 if result is False:

south/migration/__init__.py

 
 import sys
 
-from django.db import models
 from django.core.exceptions import ImproperlyConfigured
-from django.core.management import call_command
 
 from south import exceptions
 from south.models import MigrationHistory
 from south.db import db
 from south.migration.base import all_migrations, Migrations
 from south.migration.migrators import (Backwards, Forwards,
-                                       DryRunMigrator, FakeMigrator)
+                                       DryRunMigrator, FakeMigrator,
+                                       LoadInitialDataMigrator)
 from south.signals import pre_migrate, post_migrate
 
 
     if ghosts:
         raise exceptions.GhostMigrations(ghosts)
 
-def migrate_app(migrations, target_name=None, resolve_mode=None, fake=False, db_dry_run=False, yes=False, verbosity=0, load_inital_data=False, skip=False):
+def get_migrator(direction, db_dry_run, fake, verbosity, load_initial_data):
+    if direction == 1:
+        migrator = Forwards(verbosity=verbosity)
+    elif direction == -1:
+        migrator = Backwards(verbosity=verbosity)
+    else:
+        return None
+    if db_dry_run:
+        migrator = DryRunMigrator(migrator=migrator)
+    elif fake:
+        migrator = FakeMigrator(migrator=migrator)
+    elif load_initial_data:
+        migrator = LoadInitialDataMigrator(migrator=migrator)
+    return migrator
+
+def migrate_app(migrations, target_name=None, resolve_mode=None, fake=False, db_dry_run=False, yes=False, verbosity=0, load_initial_data=False, skip=False):
     
     app_name = migrations.app_name()
     app = migrations._migrations
         print " ! The following options are available:"
         print "    --merge: will just attempt the migration ignoring any potential dependency conflicts."
         sys.exit(1)
-
+    # Perform the migration
+    migrator = get_migrator(direction,
+                            db_dry_run, fake, verbosity, load_initial_data)
+    if verbosity:
+        if migrator:
+            print migrator.title(target)
+        else:
+            print '- Nothing to migrate.'
     if direction == 1:
-        migrator = Forwards(verbosity=verbosity)
-        if db_dry_run:
-            migrator = DryRunMigrator(migrator=migrator)
-        elif fake:
-            migrator = FakeMigrator(migrator=migrator)
-        if verbosity:
-            print " - Migrating forwards to %s." % target_name
-        try:
-            for migration in missing_forwards:
-                result = migrator.migrate(migration)
-                if result is False: # The migrations errored, but nicely.
-                    return False
-        finally:
-            # Call any pending post_syncdb signals
-            db.send_pending_create_signals()
-        # Now load initial data, only if we're really doing things and ended up at current
-        if not fake and not db_dry_run and load_inital_data and target == migrations[-1]:
-            if verbosity:
-                print " - Loading initial data for %s." % app_name
-            # Override Django's get_apps call temporarily to only load from the
-            # current app
-            old_get_apps, models.get_apps = (
-                models.get_apps,
-                lambda: [models.get_app(app_name)],
-            )
-            # Load the initial fixture
-            call_command('loaddata', 'initial_data', verbosity=verbosity)
-            # Un-override
-            models.get_apps = old_get_apps
+        success = migrator.migrate_many(target, missing_forwards)
     elif direction == -1:
-        migrator = Backwards(verbosity=verbosity)
-        if db_dry_run:
-            migrator = DryRunMigrator(migrator=migrator)
-        elif fake:
-            migrator = FakeMigrator(migrator=migrator)
-        if verbosity:
-            print " - Migrating backwards to just after %s." % target_name
-        for migration in present_backwards:
-            migrator.migrate(migration)
-    else:
-        if verbosity:
-            print "- Nothing to migrate."
-    
+        success = migrator.migrate_many(target, present_backwards)
     # Finally, fire off the post-migrate signal
-    post_migrate.send(None, app=app_name)
+    if success:
+        post_migrate.send(None, app=app_name)

south/migration/migrators.py

 from copy import copy
+from cStringIO import StringIO
 import datetime
 import inspect
 import sys
 import traceback
 
+from django.core.management import call_command
+from django.db import models
+
 from south import exceptions
 from south.db import db
 from south.models import MigrationHistory
     def __init__(self, verbosity=0):
         self.verbosity = int(verbosity)
 
+    @staticmethod
+    def title(target):
+        raise NotImplementedError()
+        
+    @staticmethod
+    def status(target):
+        raise NotImplementedError()
+
     def print_status(self, migration):
         status = self.status(migration)
         if self.verbosity and status:
             print status
 
-    def orm(self, migration):
+    @staticmethod
+    def orm(migration):
         raise NotImplementedError()
 
     def backwards(self, migration):
     def direction(self, migration):
         raise NotImplementedError()
 
-    def print_backwards(self, migration):
-        old_debug, old_dry_run = db.debug, db.dry_run
-        db.debug = db.dry_run = True
-        try:
-            self.backwards(migration)()
-        except:
-            db.debug, db.dry_run = old_debug, old_dry_run
-            raise
-
     @staticmethod
     def _wrap_direction(direction, orm):
         args = inspect.getargspec(direction)
             return direction
         return (lambda: direction(orm))
 
+    @staticmethod
+    def record(migration):
+        raise NotImplementedError()
+
+    def run_migration_error(self, migration, extra_info=''):
+        return (' ! Error found during real run of migration! Aborting.\n'
+                '\n'
+                ' ! Since you have a database that does not support running\n'
+                ' ! schema-altering statements in transactions, we have had \n'
+                ' ! to leave it in an interim state between migrations.\n'
+                '%s\n'
+                ' ! The South developers regret this has happened, and would\n'
+                ' ! like to gently persuade you to consider a slightly\n'
+                ' ! easier-to-deal-with DBMS.\n') % extra_info
+
     def run_migration(self, migration):
         migration_function = self.direction(migration)
         db.start_transaction()
         except:
             db.rollback_transaction()
             if not db.has_ddl_transactions:
-                print ' ! Error found during real run of migration! Aborting.'
-                print
-                print ' ! Since you have a database that does not support running'
-                print ' ! schema-altering statements in transactions, we have had to'
-                print ' ! leave it in an interim state between migrations.'
-                if self.torun == 'forwards':
-                    print
-                    print " ! You *might* be able to recover with:"
-                    self.print_backwards(migration)
-                print
-                print ' ! The South developers regret this has happened, and would'
-                print ' ! like to gently persuade you to consider a slightly'
-                print ' ! easier-to-deal-with DBMS.'
+                print self.run_migration_error(migration)
             raise
         else:
             db.commit_transaction()
         self.send_ran_migration(migration)
         return result
 
+    def migrate_many(self, target, migrations):
+        raise NotImplementedError()
+
 
 class MigratorWrapper(object):
     def __init__(self, migrator, *args, **kwargs):
         pass
 
 
+class LoadInitialDataMigrator(MigratorWrapper):
+    def load_initial_data(self, target):
+        if target != target.migrations[-1]:
+            return
+        # Load initial data, if we ended up at target
+        if self.verbosity:
+            print " - Loading initial data for %s." % target.app_name()
+        # Override Django's get_apps call temporarily to only load from the
+        # current app
+        old_get_apps = models.get_apps
+        models.get_apps = lambda: [models.get_app(target.app_name())]
+        try:
+            call_command('loaddata', 'initial_data', verbosity=self.verbosity)
+        finally:
+            models.get_apps = old_get_apps
+
+    def migrate_many(self, target, migrations):
+        migrator = self._migrator
+        result = migrator.__class__.migrate_many(migrator, target, migrations)
+        if result:
+            self.load_initial_data(target)
+        return True
+
+
 class Forwards(Migrator):
     """
     Runs the specified migration forwards, in order.
     """
     torun = 'forwards'
 
-    def status(self, migration):
+    @staticmethod
+    def title(target):
+        return " - Migrating forwards to %s." % target.name()
+
+    @staticmethod
+    def status(migration):
         return ' > %s' % migration
 
+    @staticmethod
+    def orm(migration):
+        return migration.orm()
+
     def forwards(self, migration):
         return self._wrap_direction(migration.forwards(), self.orm(migration))
 
     direction = forwards
 
-    def orm(self, migration):
-        return migration.orm()
-
     @staticmethod
     def record(migration):
         # Record us as having done this
         record.applied = datetime.datetime.utcnow()
         record.save()
 
+    def format_backwards(self, migration):
+        old_debug, old_dry_run = db.debug, db.dry_run
+        db.debug = db.dry_run = True
+        stdout = sys.stdout
+        sys.stdout = StringIO()
+        try:
+            self.backwards(migration)()
+            return sys.stdout.getvalue()
+        except:
+            raise
+        finally:
+            db.debug, db.dry_run = old_debug, old_dry_run
+            sys.stdout = stdout
+
+    def run_migration_error(self, migration, extra_info=''):
+        extra_info = ('\n'
+                      '! You *might* be able to recover with:'
+                      '%s'
+                      '%s' %
+                      (self.format_backwards(migration), extra_info))
+        return super(Forwards, self).run_migration_error(migration, extra_info)
+
+    def migrate_many(self, target, migrations):
+        try:
+            for migration in migrations:
+                result = self.migrate(migration)
+                if result is False: # The migrations errored, but nicely.
+                    return False
+        finally:
+            # Call any pending post_syncdb signals
+            db.send_pending_create_signals()
+        return True
+
 
 class Backwards(Migrator):
     """
     """
     torun = 'backwards'
 
-    def status(self, migration):
+    @staticmethod
+    def title(target):
+        return " - Migrating backwards to just after %s." % target.name()
+
+    @staticmethod
+    def status(migration):
         return ' < %s' % migration
 
-    def orm(self, migration):
+    @staticmethod
+    def orm(migration):
         return migration.prev_orm()
 
     direction = Migrator.backwards
         if record.id is not None:
             record.delete()
 
+    def migrate_many(self, target, migrations):
+        for migration in migrations:
+            self.migrate(migration)
+        return True
 
+
+

south/tests/logic.py

         self.assertEqual(list(migration.MigrationHistory.objects.all()), [])
         
         # Apply them normally
-        migration.migrate_app(migrations, target_name=None, resolve_mode=None, fake=False, verbosity=0)
+        migration.migrate_app(migrations, target_name=None, resolve_mode=None, fake=False, verbosity=0, load_initial_data=True)
         
         # We should finish with all migrations
         self.assertListEqual(
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.