dairiki avatar dairiki committed 32fb7e8

Make version table name configurable.

Comments (0)

Files changed (4)

   MySQL needs the constraint type
   in order to emit a DROP CONSTRAINT. #44
 
+- [feature] Added version_table argument to
+  EnvironmentContext.configure(), allowing for the
+  configuration of the version table name. #34
+
 0.3.2
 =====
 - [feature] Basic support for Oracle added, 

alembic/environment.py

          when using ``--sql`` mode.
         :param tag: a string tag for usage by custom ``env.py`` scripts.  
          Set via the ``--tag`` option, can be overridden here.
-     
+        :param version_table: The name of the Alembic version table.
+         The default is ``'alembic_version'``.
+
         Parameters specific to the autogenerate feature, when 
         ``alembic revision`` is run with the ``--autogenerate`` feature:
     

alembic/migration.py

 import logging
 log = logging.getLogger(__name__)
 
-_meta = MetaData()
-_version = Table('alembic_version', _meta, 
-                Column('version_num', String(32), nullable=False)
-            )
-
 class MigrationContext(object):
     """Represent the database state made available to a migration 
     script.
                                             'compare_server_default', 
                                             False)
 
+        version_table = opts.get('version_table', 'alembic_version')
+        self._version = Table(
+            version_table, MetaData(),
+            Column('version_num', String(32), nullable=False))
+
         self._start_from_rev = opts.get("starting_rev")
         self.impl = ddl.DefaultImpl.get_by_dialect(dialect)(
                             dialect, self.connection, self.as_sql,
                 raise util.CommandError(
                     "Can't specify current_rev to context "
                     "when using a database connection")
-            _version.create(self.connection, checkfirst=True)
-        return self.connection.scalar(_version.select())
+            self._version.create(self.connection, checkfirst=True)
+        return self.connection.scalar(self._version.select())
 
     _current_rev = get_current_revision
     """The 0.2 method name, for backwards compat."""
         if old == new:
             return
         if new is None:
-            self.impl._exec(_version.delete())
+            self.impl._exec(self._version.delete())
         elif old is None:
-            self.impl._exec(_version.insert().
+            self.impl._exec(self._version.insert().
                         values(version_num=literal_column("'%s'" % new))
                     )
         else:
-            self.impl._exec(_version.update().
+            self.impl._exec(self._version.update().
                         values(version_num=literal_column("'%s'" % new))
                     )
 
             if current_rev is False:
                 current_rev = prev_rev
                 if self.as_sql and not current_rev:
-                    _version.create(self.connection)
+                    self._version.create(self.connection)
             log.info("Running %s %s -> %s", change.__name__, prev_rev, rev)
             if self.as_sql:
                 self.impl.static_output(
                 self._update_current_rev(current_rev, rev)
 
             if self.as_sql and not rev:
-                _version.drop(self.connection)
+                self._version.drop(self.connection)
 
     def execute(self, sql):
         """Execute a SQL construct or string statement.

tests/test_version_table.py

+import unittest
+
+from sqlalchemy import Table, MetaData, Column, String, create_engine
+from sqlalchemy.engine.reflection import Inspector
+
+from alembic.util import CommandError
+
+version_table = Table('version_table', MetaData(),
+                      Column('version_num', String(32), nullable=False))
+
+class TestMigrationContext(unittest.TestCase):
+    _bind = []
+
+    @property
+    def bind(self):
+        if not self._bind:
+            engine = create_engine('sqlite:///', echo=True)
+            self._bind.append(engine)
+        return self._bind[0]
+
+    def setUp(self):
+        self.connection = self.bind.connect()
+        self.transaction = self.connection.begin()
+
+    def tearDown(self):
+        version_table.drop(self.connection, checkfirst=True)
+        self.transaction.rollback()
+
+    def make_one(self, **kwargs):
+        from alembic.migration import MigrationContext
+        return MigrationContext.configure(**kwargs)
+
+    def get_revision(self):
+        result = self.connection.execute(version_table.select())
+        rows = result.fetchall()
+        if len(rows) == 0:
+            return None
+        self.assertEqual(len(rows), 1)
+        return rows[0]['version_num']
+
+    def test_config_default_version_table_name(self):
+        context = self.make_one(dialect_name='sqlite')
+        self.assertEqual(context._version.name, 'alembic_version')
+
+    def test_config_explicit_version_table_name(self):
+        context = self.make_one(dialect_name='sqlite',
+                                opts={'version_table': 'explicit'})
+        self.assertEqual(context._version.name, 'explicit')
+
+    def test_get_current_revision_creates_version_table(self):
+        context = self.make_one(connection=self.connection,
+                                opts={'version_table': 'version_table'})
+        self.assertEqual(context.get_current_revision(), None)
+        insp = Inspector(self.connection)
+        self.assertTrue('version_table' in insp.get_table_names())
+
+    def test_get_current_revision(self):
+        context = self.make_one(connection=self.connection,
+                                opts={'version_table': 'version_table'})
+        version_table.create(self.connection)
+        self.assertEqual(context.get_current_revision(), None)
+        self.connection.execute(
+            version_table.insert().values(version_num='revid'))
+        self.assertEqual(context.get_current_revision(), 'revid')
+
+    def test_get_current_revision_error_if_starting_rev_given_online(self):
+        context = self.make_one(connection=self.connection,
+                                opts={'starting_rev': 'boo'})
+        self.assertRaises(CommandError, context.get_current_revision)
+
+    def test_get_current_revision_offline(self):
+        context = self.make_one(dialect_name='sqlite',
+                                opts={'starting_rev': 'startrev',
+                                      'as_sql': True})
+        self.assertEqual(context.get_current_revision(), 'startrev')
+
+    def test__update_current_rev(self):
+        version_table.create(self.connection)
+        context = self.make_one(connection=self.connection,
+                                opts={'version_table': 'version_table'})
+
+        context._update_current_rev(None, 'a')
+        self.assertEqual(self.get_revision(), 'a')
+        context._update_current_rev('a', 'b')
+        self.assertEqual(self.get_revision(), 'b')
+        context._update_current_rev('b', None)
+        self.assertEqual(self.get_revision(), None)
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.