# trac-dvbcronrecording-plugin/src/dvbcronrecording/db/schemachange.py

# NOTE: see also Alembic, a schema migration tool by the author of SQLAlchemy:
# https://github.com/sqlalchemy/alembic (formerly https://bitbucket.org/zzzeek/alembic)


import schemadiff
import tsab2 as tsab
import sys

import sqlalchemy

def alter_sql_AtoB(metadataA, metadataB, env, excludeTables=None):
    """ Yield the SQL statements that would turn schema A into schema B.

        The statements are only generated, never executed.  The SchemaDiff
        between ``metadataA`` and ``metadataB`` (skipping ``excludeTables``)
        is computed lazily, on the first iteration. """
    statements = Changer(
        schemadiff.SchemaDiff(metadataA, metadataB, excludeTables=excludeTables),
        env).sqlAtoB()
    for statement in statements:
        yield statement

def alter_sql_BtoA(metadataA, metadataB, env, excludeTables=None):
    """ Yield the SQL statements that would turn schema B back into schema A.

        The statements are only generated, never executed.  The SchemaDiff
        between ``metadataA`` and ``metadataB`` (skipping ``excludeTables``)
        is computed lazily, on the first iteration. """
    statements = Changer(
        schemadiff.SchemaDiff(metadataA, metadataB, excludeTables=excludeTables),
        env).sqlBtoA()
    for statement in statements:
        yield statement

def migrate_AtoB(metadataA, metadataB, env, excludeTables=None):
    """ Migrate the database from schema A to schema B, yielding each step.

        Every statement is executed before it is yielded; the session is
        committed only after the last one, so callers must exhaust the
        generator.  The SchemaDiff between ``metadataA`` and ``metadataB``
        (skipping ``excludeTables``) is computed lazily. """
    statements = Changer(
        schemadiff.SchemaDiff(metadataA, metadataB, excludeTables=excludeTables),
        env).migrateAtoB()
    for statement in statements:
        yield statement

def migrate_BtoA(metadataA, metadataB, env, excludeTables=None):
    """ Migrate the database from schema B back to schema A, yielding each step.

        Every statement is executed before it is yielded; the session is
        committed only after the last one, so callers must exhaust the
        generator.  The SchemaDiff between ``metadataA`` and ``metadataB``
        (skipping ``excludeTables``) is computed lazily. """
    statements = Changer(
        schemadiff.SchemaDiff(metadataA, metadataB, excludeTables=excludeTables),
        env).migrateBtoA()
    for statement in statements:
        yield statement

import sqlalchemy.sql.expression
class SchemaChangeElement(sqlalchemy.sql.expression.ClauseElement):
    """ Dummy clause handed to the DDL compiler when it is constructed.

        SQLAlchemy dispatches compilation on ``__visit_name__``; the compiler
        built in ``Changer.ddl_compiler`` defines a no-op
        ``visit_schema_change`` so compiling this element does nothing. """
    __visit_name__ = "schema_change"
    
class Changer:
    """ Translate a SchemaDiff into SQL statements and optionally apply them.

        ``sqlAtoB``/``sqlBtoA`` only yield the SQL text.  ``migrateAtoB``/
        ``migrateBtoA`` also execute every step (through the ``tsab`` engine
        and session) and commit the session after the last statement.

        All steps are generators: nothing happens until the caller iterates,
        and a migration is only committed if it is iterated to the end.
        DDL compilation is only implemented for the ``sqlite`` dialect. """

    def __init__(self, schemadiff, env):
        self.schemadiff = schemadiff  # SchemaDiff describing A vs. B
        self.env = env                # Trac environment (DB connection info)
        self.mode = None              # not read anywhere in this class
        self._ddl_dialect = None      # lazily built, see ddl_dialect()
        self._ddl_compiler = None     # lazily built, see ddl_compiler()
        self._recreated = []          # tables (re)created by create()
        self._dialect = None          # lazily read, see dialect()
        self._engine = None
        self._session = None

    def engine(self):
        """ return the engine from ``tsab.engine(env)``, created once """
        if self._engine is None:
            self._engine = tsab.engine(self.env)
        return self._engine

    def session(self):
        """ return the session from ``tsab.session(env)``, created once """
        if self._session is None:
            self._session = tsab.session(self.env)
        return self._session

    def commit(self):
        """ commit the session, if one was ever opened """
        if self._session is not None:
            self._session.commit()

    def dialect(self):
        """ return the scheme of the Trac database connection URI, i.e. the
            part before the first ':' (e.g. ``sqlite``) """
        if self._dialect is None:
            import trac.db
            uri = trac.db.DatabaseManager(self.env).connection_uri
            self._dialect = uri.split(':')[0]
        return self._dialect

    def _steps(self, migrate, backwards):
        """ yield (and, if ``migrate``, execute) every statement needed to
            turn schema A into schema B -- or B into A when ``backwards`` """
        # _recreated tracks tables rebuilt during *this* pass only.
        self._recreated = []
        diff = self.schemadiff
        if backwards:
            source, target = diff.metadataB, diff.metadataA
            to_create = diff.tables_missing_from_B
            to_drop = diff.tables_missing_from_A
        else:
            source, target = diff.metadataA, diff.metadataB
            to_create = diff.tables_missing_from_A
            to_drop = diff.tables_missing_from_B
        for tablename in to_create:
            for sql in self.create(migrate, source, target, tablename):
                yield sql
        for tablename in to_drop:
            for sql in self.drop(migrate, source, target, tablename):
                yield sql
        for tablename, td in sorted(diff.tables_different.items()):
            if backwards:
                added = td.columns_missing_from_B
                removed = td.columns_missing_from_A
            else:
                added = td.columns_missing_from_A
                removed = td.columns_missing_from_B
            for sql in self.append(migrate, source, target, tablename, added):
                yield sql
            for sql in self.remove(migrate, source, target, tablename, removed):
                yield sql
            # remove() may already have rebuilt the table from the target
            # definition (which includes any type changes); rebuilding it a
            # second time in change() would be redundant.
            if tablename in self._recreated:
                continue
            if backwards:
                changed = [(b, a) for a, b in td.columns_different.items()]
            else:
                changed = list(td.columns_different.items())
            for sql in self.change(migrate, source, target, tablename, changed):
                yield sql

    def sqlAtoB(self):
        """ yield the SQL that turns schema A into schema B (not executed) """
        for sql in self._steps(False, False):
            yield sql

    def sqlBtoA(self):
        """ yield the SQL that turns schema B into schema A (not executed) """
        for sql in self._steps(False, True):
            yield sql

    def migrateAtoB(self):
        """ execute and yield the A -> B migration, then commit """
        for sql in self._steps(True, False):
            yield sql
        self.commit()

    def migrateBtoA(self):
        """ execute and yield the B -> A migration, then commit """
        for sql in self._steps(True, True):
            yield sql
        self.commit()

    def ddl_dialect(self):
        """ return the SQLAlchemy dialect used for DDL compilation, or None
            (after a warning on stderr) if the dialect is unsupported """
        if self._ddl_dialect is None:
            if self.dialect() in ["sqlite"]:
                import sqlalchemy.dialects.sqlite.base
                self._ddl_dialect = sqlalchemy.dialects.sqlite.base.SQLiteDialect()
            else:
                sys.stderr.write("unknown dialect %s\n" % (self.dialect(),))
        return self._ddl_dialect

    def ddl_compiler(self):
        """ we do all DDL compilation ourselves. But we take advantage of
            the ddl.get_column_specification(coldef) sql conversions.

            Returns None (after a warning on stderr) for unsupported
            dialects. """
        if self._ddl_compiler is None:
            if self.dialect() in ["sqlite"]:
                sa_dialect = self.ddl_dialect()
                import sqlalchemy.dialects.sqlite.base

                class SQLiteDDL(sqlalchemy.dialects.sqlite.base.SQLiteDDLCompiler):
                    def visit_schema_change(self, s, **kw):
                        """ SQLAlchemy 0.7 compiles immediately on __init__;
                            compiling our dummy statement is a no-op """
                        pass

                self._ddl_compiler = SQLiteDDL(sa_dialect, SchemaChangeElement())
            else:
                sys.stderr.write("unknown dialect %s\n" % (self.dialect(),))
        return self._ddl_compiler

    def table(self, metadata, tablename):
        """ return the table named ``tablename`` in ``metadata``, or None """
        return metadata.tables.get(tablename)

    def create(self, migrate, metadata1, metadata2, table):
        """ yield CREATE TABLE for ``table`` as defined in ``metadata2``;
            when migrating, the table is created through the engine """
        ddl = self.ddl_compiler()
        tabledef = self.table(metadata2, table)
        specs = [ddl.get_column_specification(tabledef.columns.get(name))
                 for name in tabledef.columns.keys()]
        sql = "CREATE TABLE %s (%s)" % (table, ", ".join(specs))
        if migrate:
            self.engine().create(tabledef)
        # remember the rebuild so _steps() can skip change() for this table
        self._recreated.append(table)
        yield sql

    def drop(self, migrate, metadata1, metadata2, table):
        """ yield DROP TABLE for ``table``, which exists in the source schema
            ``metadata1`` but not in the target schema ``metadata2`` """
        sql = "DROP TABLE %s" % table
        if migrate:
            # The dropped table is by definition absent from metadata2, so
            # its definition has to come from metadata1.
            tabledef = self.table(metadata1, table)
            if tabledef is not None:
                self.engine().drop(tabledef)
            else:
                self.session().execute(sql)
        yield sql

    def drops(self, migrate, table):
        """ drops a renamed (temporary) table with plain SQL """
        for sql in self.execute(migrate, "DROP TABLE %s" % table):
            yield sql

    def append(self, migrate, metadata1, metadata2, table, newcols):
        """ yield one ALTER TABLE ... ADD COLUMN per name in ``newcols``,
            using the column definitions from ``metadata2`` """
        # ADD COLUMN is used for every dialect (SQLite supports it), so no
        # table rebuild is needed here.
        ddl = self.ddl_compiler()
        tabledef = self.table(metadata2, table)
        for newcol in newcols:
            spec = ddl.get_column_specification(tabledef.columns.get(newcol))
            sql = "ALTER TABLE %s ADD COLUMN %s" % (table, spec)
            for sql in self.execute(migrate, sql):
                yield sql

    def remove(self, migrate, metadata1, metadata2, table, oldcols):
        """ yield the SQL that drops the columns ``oldcols`` from ``table`` """
        dialect = self.dialect()
        if dialect in ["firebird"]:
            template = "ALTER TABLE %s DROP %s"
        elif dialect in ["oracle", "postgres"]:
            template = "ALTER TABLE %s DROP COLUMN %s"
        else:
            template = None
        if template is not None:
            for oldcol in oldcols:
                for sql in self.execute(migrate, template % (table, oldcol)):
                    yield sql
            return
        # Generic fallback (SQLite before 3.35 has no DROP COLUMN): rebuild
        # the table from metadata2 and copy the surviving columns across.
        # This runs even when ``oldcols`` is empty; the rebuild then still
        # brings the table in line with metadata2, which is why _steps()
        # skips change() afterwards.
        columnnames = self.table(metadata2, table).columns.keys()
        keptcols = [col for col in columnnames if col not in oldcols]
        temp_table = "migration_" + table
        for sql in self.rename(migrate, metadata1, metadata2, table, temp_table):
            yield sql
        for sql in self.create(migrate, metadata1, metadata2, table):
            yield sql
        sql = "INSERT INTO %s (%s) SELECT %s FROM %s" % (
            table, ",".join(keptcols), ",".join(keptcols), temp_table)
        for sql in self.execute(migrate, sql):
            yield sql
        for sql in self.drops(migrate, temp_table):
            yield sql

    def change(self, migrate, metadata1, metadata2, table, changed):
        """ bring the columns of ``table`` in line with ``metadata2``.

            ``changed`` is a list of pairs built from
            ``TableDiff.columns_different``; only the firebird/oracle
            branches consult it. """
        dialect = self.dialect()
        if dialect in ["firebird"]:
            # NOTE(review): unreachable in practice -- ddl_compiler() only
            # supports sqlite, and ``changed`` is a list, so .values() would
            # fail. Also, Firebird's "ALTER COLUMN x TO y" renames a column;
            # a type change would need "ALTER COLUMN x TYPE ...". Verify
            # before enabling.
            ddl = self.ddl_compiler()
            for newcol in changed.values():
                coldef = self.table(metadata2, table).columns.get(newcol)
                spec = ddl.get_column_specification(coldef)
                sql = "ALTER TABLE %s ALTER COLUMN %s TO %s" % (table, newcol, spec)
                for sql in self.execute(migrate, sql):
                    yield sql
        elif dialect in ["oracle"]:
            # NOTE(review): unreachable in practice, same as above; Oracle's
            # syntax is presumably "MODIFY (col spec)" -- verify.
            ddl = self.ddl_compiler()
            for newcol in changed.values():
                coldef = self.table(metadata2, table).columns.get(newcol)
                spec = ddl.get_column_specification(coldef)
                sql = "ALTER TABLE %s MODIFY COLUMN (%s)" % (table, spec)
                for sql in self.execute(migrate, sql):
                    yield sql
        else:
            # Rebuild the table from metadata2 and copy the data by column
            # name, so a differing column order cannot misplace values.
            # Assumes append() has already added any new columns.
            temp_table = "migration_" + table
            for sql in self.rename(migrate, metadata1, metadata2, table, temp_table):
                yield sql
            for sql in self.create(migrate, metadata1, metadata2, table):
                yield sql
            columns = ",".join(self.table(metadata2, table).columns.keys())
            sql = "INSERT INTO %s (%s) SELECT %s FROM %s" % (
                table, columns, columns, temp_table)
            for sql in self.execute(migrate, sql):
                yield sql
            for sql in self.drops(migrate, temp_table):
                yield sql

    def rename(self, migrate, metadata1, metadata2, table, temp_table):
        """ yield ALTER TABLE ... RENAME TO moving ``table`` to ``temp_table`` """
        sql = "ALTER TABLE %s RENAME TO %s" % (table, temp_table)
        for sql in self.execute(migrate, sql):
            yield sql

    def execute(self, migrate, sql):
        """ run ``sql`` on the session when migrating, then yield it """
        if migrate:
            self.session().execute(sql)
        yield sql

    def show(self, metadata):
        """ debugging aid: print every column of every table in ``metadata`` """
        for tabledef in metadata.tables.values():
            for coldef in tabledef.columns:
                print(str(coldef))
        
        
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.