Commits

wilsaj committed 2669f63 Draft

get autogenerate working with postgresql schemas

Comments (0)

Files changed (1)

alembic/autogenerate.py

 from sqlalchemy.engine.reflection import Inspector
 from sqlalchemy.util import OrderedSet
 from sqlalchemy import schema, types as sqltypes
+import itertools
 import re
 
 import logging
                             include_symbol=None):
     inspector = Inspector.from_engine(connection)
     # TODO: not hardcode alembic_version here ?
-    conn_table_names = set(inspector.get_table_names()).\
-                            difference(['alembic_version'])
+    schema_names= set(inspector.get_schema_names()).difference(['information_schema', 'public'])
+    conn_table_names = set(itertools.chain(*[
+        [schema_name + '.' + table_name
+         for table_name in inspector.get_table_names(schema=schema_name)]
+        for schema_name in schema_names]))
+    #conn_table_names = set(inspector.get_table_names()).\
+                            #difference(['alembic_version'])
 
 
-    metadata_table_names = OrderedSet([table.name
+    #metadata_table_names = OrderedSet([table.name
+                                #for table in metadata.sorted_tables])
+
+    metadata_table_names = OrderedSet([table.schema + '.' + table.name
                                 for table in metadata.sorted_tables])
-
     if include_symbol:
         conn_table_names = set(name for name in conn_table_names
                             if include_symbol(name))
 
 def _compare_tables(conn_table_names, metadata_table_names,
                     inspector, metadata, diffs, autogen_context):
+    import pdb; pdb.set_trace()
     for tname in metadata_table_names.difference(conn_table_names):
         diffs.append(("add_table", metadata.tables[tname]))
         log.info("Detected added table %r", tname)
 
     removal_metadata = schema.MetaData()
-    for tname in conn_table_names.difference(metadata_table_names):
+    for tfullname in conn_table_names.difference(metadata_table_names):
+        tschema, tname = tfullname.split('.')
         exists = tname in removal_metadata.tables
-        t = schema.Table(tname, removal_metadata)
+        t = schema.Table(tname, removal_metadata, schema=tschema)
         if not exists:
             inspector.reflecttable(t, None)
         diffs.append(("remove_table", t))
         log.info("Detected removed table %r", tname)
 
-    existing_tables = conn_table_names.intersection(metadata_table_names)
+    existing_tables = [tfullname.split('.')
+            for tfullname in conn_table_names.intersection(metadata_table_names)]
 
     conn_column_info = dict(
         (tname,
             dict(
                 (rec["name"], rec)
-                for rec in inspector.get_columns(tname)
+                for rec in inspector.get_columns(tname, schema=tschema)
             )
         )
-        for tname in existing_tables
+        for tschema, tname in existing_tables
     )
 
-    for tname in sorted(existing_tables):
+    for tschema, tname in sorted(existing_tables):
         _compare_columns(tname,
                 conn_column_info[tname],
-                metadata.tables[tname],
+                metadata.tables[tschema + '.' + tname],
                 diffs, autogen_context)
 
     # TODO: