Commits

dieselmachine committed 376263d

initial commit

Comments (0)

Files changed (6)

alembic/autogenerate.py

     _produce_net_changes(connection, metadata, diffs, autogen_context)
     return diffs
 
+def get_numeric_input(max_value):
+    while True:
+        cmd = raw_input('Enter a number 1 to %s: ' % max_value)
+        if not cmd:
+            print 'You must enter a value'
+            continue
+        if not cmd.isdigit():
+            print "'%s' isn't a number" % cmd
+            continue
+        cmd = int(cmd)
+        if cmd > max_value:
+            print 'Value must be between 1 and %s' % max_value
+            continue
+        return cmd
+
 ###################################################
 # top level
 
-def _produce_migration_diffs(context, template_args, imports):
+def _produce_migration_diffs(context, template_args, imports,
+                             interactive=False, initial=False):
     opts = context.opts
     metadata = opts['target_metadata']
     if metadata is None:
     autogen_context, connection = _autogen_context(context, imports)
 
     diffs = []
-    _produce_net_changes(connection, metadata, diffs, autogen_context)
+    _produce_net_changes(connection, metadata, diffs, autogen_context,
+        interactive, initial)
     template_args[opts['upgrade_token']] = \
             _indent(_produce_upgrade_commands(diffs, autogen_context))
     template_args[opts['downgrade_token']] = \
 ###################################################
 # walk structures
 
-def _produce_net_changes(connection, metadata, diffs, autogen_context):
+def _produce_net_changes(connection, metadata, diffs, autogen_context,
+                         interactive=False, initial=False):
     inspector = Inspector.from_engine(connection)
-    # TODO: not hardcode alembic_version here ?
-    conn_table_names = set(inspector.get_table_names()).\
-                            difference(['alembic_version'])
-    metadata_table_names = OrderedSet([table.name for table in metadata.sorted_tables])
+    metadata_table_names = OrderedSet([table.fullname \
+        for table in metadata.sorted_tables])
+    conn_table_names = set()
+    
+    default_schema = metadata.bind.url.database if metadata.bind else None
+    active_schemas = set([table.schema or default_schema \
+        for table in metadata.sorted_tables])
+    
+    if not initial:
+        for schema in active_schemas:
+            """
+            we need these names to match the keys in metadata_table_names
+            which means including the schema if it exists on the table
+            """
+            for name in inspector.get_table_names(schema=schema):
+                if schema == default_schema:
+                    if '%s.%s' % (schema,name) in metadata_table_names:
+                        conn_table_names.add('%s.%s' % (schema, name))
+                    else:
+                        conn_table_names.add(name)
+                else:
+                    conn_table_names.add('%s.%s' % (schema, name))
 
     _compare_tables(conn_table_names, metadata_table_names,
-                    inspector, metadata, diffs, autogen_context)
+                    inspector, metadata, diffs, autogen_context, interactive)
 
 def _compare_tables(conn_table_names, metadata_table_names, 
-                    inspector, metadata, diffs, autogen_context):
-    for tname in metadata_table_names.difference(conn_table_names):
-        diffs.append(("add_table", metadata.tables[tname]))
-        log.info("Detected added table %r", tname)
+                    inspector, metadata, diffs, autogen_context, 
+                    interactive=False):
+    added_tables = metadata_table_names.difference(conn_table_names)
+    removed_tables = conn_table_names.difference(metadata_table_names)
+    existing_tables = conn_table_names.intersection(metadata_table_names)
+    ignored_tables = autogen_context['context'].get_ignored_tables()
+
+    removed_tables = [t for t in removed_tables if t not in ignored_tables]
+
+    conn_column_info = {}
+    conn_metadata = schema.MetaData()
+
+    for tname in existing_tables:
+        sch, name = (None, tname)
+        if tname.find('.') > -1:
+            sch, name = tname.rsplit('.',1)
+
+        exists = tname in conn_metadata.tables
+        t = schema.Table(name, conn_metadata, schema=sch)
+        if not exists:
+            inspector.reflecttable(t, None)
+
+        conn_column_info[tname] = dict(
+            (rec["name"], rec)
+            for rec in inspector.get_columns(name, schema=sch)
+        )
+            
+    for tname in added_tables:
+        sch, name = (None, tname)
+        if tname.find('.') > -1:
+            sch, name = tname.rsplit('.',1)
+
+        if interactive:
+            if sch:
+                " only detect renames if they happened within the same schema "
+                potential_renames = [table for table in removed_tables \
+                    if table.startswith('%s.' % sch)]
+            else:
+                " no schema provided, only check local tables "
+                potential_renames = [table for table in removed_tables \
+                    if table.find('.') == -1]
+                    
+            if not potential_renames:
+                diffs.append(("add_table", metadata.tables[tname]))
+                log.info("Detected added table %r", tname)
+            else:
+                print 'detected added table %s. Action?' % tname
+                options = [
+                    (1, 'This is a new table'),
+                    ]
+                for i, table in enumerate(potential_renames):
+                    options.append( (i+2, 'This table was renamed from %s' % table) )
+                for option in options:
+                    print '%s. %s' % option
+                cmd = get_numeric_input(len(options))
+                if cmd == 1:
+                    diffs.append(("add_table", metadata.tables[tname]))
+                    log.info("Detected added table %r", tname)
+                else:
+                    " this is a rename "
+                    old_name = potential_renames[cmd-2]
+                    removed_tables.remove(old_name)
+                    existing_tables.add(tname)
+                    conn_column_info[tname] = dict(
+                        (rec["name"], rec)
+                        for rec in inspector.get_columns(old_name)
+                        )
+                    diffs.append(("rename_table", old_name, tname, sch))
+        else:
+            diffs.append(("add_table", metadata.tables[tname]))
+            log.info("Detected added table %r", tname)
+            
+
+    if interactive and removed_tables:
+        itemized = False
+        for tname in removed_tables:
+            print 'Detected deleted table "%s"' % tname
+        print 'Found %s deleted tables. Action?' % len(removed_tables)
+        options = [
+            (1, 'Mark all tables as deleted'),
+            (2, 'Add all tables to ignore list'),
+            (3, 'Itemized actions'),
+            ]
+        for option in options:
+            print '%s. %s' % option
+        cmd = get_numeric_input(len(options))
+        if cmd == 2:
+            ignore = autogen_context['context']._ignored
+            for t in removed_tables:
+                autogen_context['connection'].execute(
+                    ignore.insert().prefix_with('ignore'), table=t
+                )
+            removed_tables = set()
+        elif cmd == 3:
+            itemized = True
 
     removal_metadata = schema.MetaData()
-    for tname in conn_table_names.difference(metadata_table_names):
+    for tname in removed_tables:
+        sch, name = (None, tname)
+        if tname.find('.') > -1:
+            sch, name = tname.rsplit('.',1)
+
+        if interactive and itemized:
+            print 'Found deleted table "%s". Action?' % tname
+            options = [
+                (1, 'Add to ignore list'),
+                (2, 'Mark as deleted'),
+                ]
+            for option in options:
+                print '%s. %s' % option
+
+            cmd = get_numeric_input(len(options))
+            if cmd == 1:
+                ignore = autogen_context['context']._ignored
+                autogen_context['connection'].execute(
+                    ignore.insert().prefix_with('ignore'), table=tname
+                )
+                continue
+
         exists = tname in removal_metadata.tables
-        t = schema.Table(tname, removal_metadata)
+        t = schema.Table(name, removal_metadata, schema=sch)
         if not exists:
             inspector.reflecttable(t, None)
         diffs.append(("remove_table", t))
         log.info("Detected removed table %r", tname)
 
-    existing_tables = conn_table_names.intersection(metadata_table_names)
-
-    conn_column_info = dict(
-        (tname, 
-            dict(
-                (rec["name"], rec)
-                for rec in inspector.get_columns(tname)
-            )
-        )
-        for tname in existing_tables
-    )
-
     for tname in sorted(existing_tables):
         _compare_columns(tname, 
                 conn_column_info[tname], 
                 metadata.tables[tname],
-                diffs, autogen_context)
-
+                diffs, autogen_context,
+                interactive=interactive)
     # TODO: 
     # index add/drop
     # table constraints
 ###################################################
 # element comparison
 
-def _compare_columns(tname, conn_table, metadata_table, diffs, autogen_context):
+def _compare_columns(tname, conn_table, metadata_table, diffs, autogen_context,
+                     interactive=False, schema=None):
+    sch, name = (schema, tname)
+    if tname.find('.') > -1:
+        sch, name = tname.rsplit('.',1)
+                     
     metadata_cols_by_name = dict((c.name, c) for c in metadata_table.c)
     conn_col_names = set(conn_table)
     metadata_col_names = set(metadata_cols_by_name)
 
-    for cname in metadata_col_names.difference(conn_col_names):
-        diffs.append(
-            ("add_column", tname, metadata_cols_by_name[cname])
-        )
-        log.info("Detected added column '%s.%s'", tname, cname)
+    added_columns = metadata_col_names.difference(conn_col_names)
+    removed_columns = conn_col_names.difference(metadata_col_names)
+    existing_columns = metadata_col_names.intersection(conn_col_names)
+    renamed_columns = {} # k:old_name, v:new_name
 
-    for cname in conn_col_names.difference(metadata_col_names):
+    for cname in added_columns:
+        if interactive:
+            if not removed_columns:
+                diffs.append(("add_column", tname, metadata_cols_by_name[cname]))
+                log.info("Detected added column '%s.%s'", tname, cname)
+            else:
+                print 'detected added column %s. Action?' % cname
+                potential_renames = list(removed_columns)
+                options = [
+                    (1, 'This is a new column'),
+                    ]
+                for i, col in enumerate(potential_renames):
+                    options.append((i+2, 'This column was renamed from %s' % col))
+                for option in options:
+                    print '%s. %s' % option
+                cmd = get_numeric_input(len(options))
+                if cmd == 1:
+                    diffs.append(("add_column", tname, metadata_cols_by_name[cname]))
+                    log.info("Detected added column '%s.%s'", tname, cname)
+                else:
+                    " this is a rename "
+                    old_name = potential_renames[cmd-2]
+                    removed_columns.remove(old_name)
+                    existing_columns.add(old_name)
+                    renamed_columns[old_name] = cname
+        else:
+            diffs.append(("add_column", tname, metadata_cols_by_name[cname]))
+            log.info("Detected added column '%s.%s'", tname, cname)
+
+    for cname in removed_columns:
         diffs.append(
             ("remove_column", tname, schema.Column(
                 cname,
         )
         log.info("Detected removed column '%s.%s'", tname, cname)
 
-    for colname in metadata_col_names.intersection(conn_col_names):
-        metadata_col = metadata_table.c[colname]
+    for colname in existing_columns:
+        metadata_colname = renamed_columns.get(colname, colname)
+        metadata_col = metadata_table.c[metadata_colname]
         conn_col = conn_table[colname]
         col_diff = []
         _compare_type(tname, colname,
             metadata_col,
             col_diff, autogen_context
         )
+        _compare_name(tname, colname,
+            conn_col,
+            metadata_col,
+            col_diff, autogen_context
+        )
         if col_diff:
             diffs.append(col_diff)
 
     isdiff = autogen_context['context']._compare_type(conn_col, metadata_col)
 
     if isdiff:
-
         diffs.append(
             ("modify_type", tname, cname, 
                     {
             cname
         )
 
-
+def _compare_name(tname, cname, conn_col, metadata_col, diffs,
+                            autogen_context):
+    isdiff = (conn_col['name'] != metadata_col.name)
+    if isdiff:
+        diffs.append(
+            ("modify_name", tname, cname, 
+                {
+                    "existing_nullable": conn_col['nullable'],
+                    "existing_type": conn_col['type'],
+                    "existing_server_default": conn_col['default'],
+                    "name": metadata_col.name,
+                },
+            ),
+        )
+        log.info("Detected renamed column: %s.%s -> %s.%s",
+            tname,
+            conn_col['name'],
+            tname,
+            metadata_col.name
+        )
+                    
 ###################################################
 # produce command structure
 
 
 def _invoke_command(updown, args, autogen_context):
     if isinstance(args, tuple):
+        if args[0] == 'rename_table':
+            return _invoke_rename_table_command(updown, args, autogen_context)
         return _invoke_adddrop_command(updown, args, autogen_context)
     else:
         return _invoke_modify_command(updown, args, autogen_context)
 
+def _invoke_rename_table_command(updown, args, autogen_context):
+    old_name, new_name = args[1:3]
+    if len(args) > 3:
+        schema = args[3]
+    else:
+        schema = None
+
+    if updown == 'upgrade':
+        return _rename_table(old_name, new_name, autogen_context, schema=schema)
+    else:
+        return _rename_table(new_name, old_name, autogen_context, schema=schema)
+
 def _invoke_adddrop_command(updown, args, autogen_context):
     cmd_type = args[0]
     adddrop, cmd_type = cmd_type.split("_")
 
 def _invoke_modify_command(updown, args, autogen_context):
     tname, cname = args[0][1:3]
-    kw = {}
+    sch, name = (None, tname)
+    if tname.find('.') > -1:
+        sch, name = tname.rsplit('.',1)
+    kw = { 'schema': sch }
 
     _arg_struct = {
         "modify_type":("existing_type", "type_"),
     }
     for diff in args:
         diff_kw = diff[3]
+        if diff_kw.get('name'):
+            if updown == "upgrade":
+                kw['name'] = diff_kw['name']
+            else:
+                kw['name'] = cname
+                cname = diff_kw['name']
+
         for arg in ("existing_type", \
                 "existing_nullable", \
                 "existing_server_default"):
             if arg in diff_kw:
                 kw.setdefault(arg, diff_kw[arg])
+
+        if len(diff) < 6:
+            " renames do not have positional args "
+            continue
+
         old_kw, new_kw = _arg_struct[diff[0]]
         if updown == "upgrade":
             kw[new_kw] = diff[-1]
         kw.pop("existing_nullable", None)
     if "server_default" in kw:
         kw.pop("existing_server_default", None)
-    return _modify_col(tname, cname, autogen_context, **kw)
+    return _modify_col(name, cname, autogen_context, **kw)
 
 ###################################################
 # render python
 
+def _rename_table(old_name, new_name, autogen_context, schema=None):
+    return "%(prefix)srename_table(%(old_name)r, %(new_name)r%(schema)s)" % {
+            'prefix': _alembic_autogenerate_prefix(autogen_context), 
+            'old_name': old_name, 
+            'new_name': new_name,
+            'schema': ", '%s'" % schema if schema else ''}
+
 def _add_table(table, autogen_context):
     return "%(prefix)screate_table(%(tablename)r,\n%(args)s\n)" % {
         'tablename':table.name,
                 [_render_constraint(cons, autogen_context) for cons in 
                     table.constraints]
                 if rcons is not None
-            ])
+            ] + ['schema="%s"' % table.schema] if table.schema else [])
         ),
     }
 
 def _drop_table(table, autogen_context):
-    return "%(prefix)sdrop_table(%(tname)r)" % {
+    return "%(prefix)sdrop_table(%(tname)r%(schema)s)" % {
             "prefix":_alembic_autogenerate_prefix(autogen_context),
-            "tname":table.name
+            "tname":table.name,
+            "schema": ", schema='%s'" % table.schema if table.schema else '',
         }
 
-def _add_column(tname, column, autogen_context):
-    return "%(prefix)sadd_column(%(tname)r, %(column)s)" % {
+def _add_column(tname, column, autogen_context, schema=None):
+    return "%(prefix)sadd_column(%(tname)r, %(column)s%(schema)s)" % {
             "prefix":_alembic_autogenerate_prefix(autogen_context),
             "tname":tname,
-            "column":_render_column(column, autogen_context)
+            "column":_render_column(column, autogen_context),
+            'schema': ", '%s'" % schema if schema else '',
             }
 
-def _drop_column(tname, column, autogen_context):
-    return "%(prefix)sdrop_column(%(tname)r, %(cname)r)" % {
+def _drop_column(tname, column, autogen_context, schema=None):
+    return "%(prefix)sdrop_column(%(tname)r, %(cname)r%(schema)s)" % {
             "prefix":_alembic_autogenerate_prefix(autogen_context),
             "tname":tname,
-            "cname":column.name
+            "cname":column.name,
+            'schema': ", '%s'" % schema if schema else '',
             }
 
 def _modify_col(tname, cname, 
                 nullable=None,
                 existing_type=None,
                 existing_nullable=None,
-                existing_server_default=False):
+                existing_server_default=False,
+                schema=None,
+                name=None):
     sqla_prefix = _sqlalchemy_autogenerate_prefix(autogen_context)
     indent = " " * 11
     text = "%(prefix)salter_column(%(tname)r, %(cname)r" % {
                             existing_server_default, 
                             autogen_context),
                     )
+    if schema:
+        text += ', \n%sschema_=%r' % (indent, schema)
+    if name:
+        text += ', \n%sname=%r' % (indent, name)
+
     text += ")"
     return text
 

alembic/command.py

     util.msg("Please edit configuration/connection/logging "\
             "settings in %r before proceeding." % config_file)
 
-def revision(config, message=None, autogenerate=False):
+def revision(config, message=None, autogenerate=False, interactive=False,
+             initial=False):
     """Create a new revision file."""
-
     script = ScriptDirectory.from_config(config)
+    current_rev = script.get_revision("head").revision
     template_args = {}
     imports = set()
-    if autogenerate:
+    if autogenerate or interactive or initial:
         util.requires_07("autogenerate")
         def retrieve_migrations(rev, context):
             if script.get_revision(rev) is not script.get_revision("head"):
                 raise util.CommandError("Target database is not up to date.")
-            autogen._produce_migration_diffs(context, template_args, imports)
+            autogen._produce_migration_diffs(context, template_args, imports,
+                interactive, initial)
             return []
 
         with EnvironmentContext(
             fn = retrieve_migrations
         ):
             script.run_env()
-    script.generate_revision(util.rev_id(), message, **template_args)
+    next_rev = current_rev if initial else util.rev_id()
+    script.generate_revision(next_rev, message, **template_args)
 
 
 def upgrade(config, revision, sql=False, tag=None):

alembic/config.py

                             help="Populate revision script with candidate "
                             "migration operations, based on comparison of database to model.")
 
+        if 'interactive' in kwargs:
+            parser.add_argument("--interactive",
+                            action="store_true",
+                            help="Populate revision script with candidate "
+                            "migration operations, based on comparison of database to model. "
+                            "Allows interactive user-input to handle cases which "
+                            "autogenerate cannot (renames).")
+        if 'initial' in kwargs:
+            parser.add_argument("--initial",
+                            action="store_true",
+                            help="Populate revision script with candidate "
+                            "migration operations, based on empty initial db.")
 
         # TODO:
         # --dialect - name of dialect when --sql mode is set - *no DB connections

alembic/ddl/impl.py

                                 existing_nullable=existing_nullable,
                             ))
 
-    def add_column(self, table_name, column):
-        self._exec(base.AddColumn(table_name, column))
+    def add_column(self, table_name, column, **kw):
+        self._exec(base.AddColumn(table_name, column, **kw))
 
     def drop_column(self, table_name, column, **kw):
-        self._exec(base.DropColumn(table_name, column))
+        self._exec(base.DropColumn(table_name, column, **kw))
 
     def add_constraint(self, const):
         if const._create_rule is None or \

alembic/migration.py

             version_table, MetaData(),
             Column('version_num', String(32), nullable=False))
 
+        ignore_table = opts.get('ignore_table', 'alembic_ignored_tables')
+        self._ignored = Table(
+            ignore_table, MetaData(), 
+                Column('table', String(255), primary_key=True),
+            )
+        self._ignored.create(self.connection, checkfirst=True)
+        connection.execute(
+            self._ignored.insert().prefix_with('ignore'),
+            *[{'table':table} for table in [version_table,ignore_table]]
+        )
+
         self._start_from_rev = opts.get("starting_rev")
         self.impl = ddl.DefaultImpl.get_by_dialect(dialect)(
                             dialect, self.connection, self.as_sql,
 
         return MigrationContext(dialect, connection, opts)
 
+    def get_ignored_tables(self):
+        """Returns a list of tables to be ignored when detected as deleted"""
+        return [r[0] for r in self.connection.execute(self._ignored.select())]
 
     def get_current_revision(self):
         """Return the current revision, usually that which is present

alembic/operations.py

                         existing_type=None,
                         existing_server_default=False,
                         existing_nullable=None,
+                        schema_ = None,
     ):
         """Issue an "alter column" instruction using the 
         current migration context.
          is not being changed; else MySQL sets this to NULL.
         """
 
+        if table_name.find('.') > -1:
+            schema_, table_name = table_name.rsplit('.',1)
+
         if existing_type:
             t = self._table(table_name, 
-                        schema.Column(column_name, existing_type)
+                        schema.Column(column_name, existing_type),
+                        schema=schema_,
                     )
             for constraint in t.constraints:
                 if not isinstance(constraint, schema.PrimaryKeyConstraint):
             existing_type=existing_type,
             existing_server_default=existing_server_default,
             existing_nullable=existing_nullable,
+            schema = schema_,
         )
 
         if type_:
-            t = self._table(table_name, schema.Column(column_name, type_))
+            t = self._table(table_name, schema.Column(column_name, type_),
+                            schema=schema_)
             for constraint in t.constraints:
                 if not isinstance(constraint, schema.PrimaryKeyConstraint):
                     self.impl.add_constraint(constraint)
             self._table(name, *columns, **kw)
         )
 
-    def drop_table(self, name):
+    def drop_table(self, name, **kw):
         """Issue a "drop table" instruction using the current 
         migration context.
 
 
         """
         self.impl.drop_table(
-            self._table(name)
+            self._table(name, **kw)
         )
 
     def create_index(self, name, tablename, *columns, **kw):