Commits

Alex Grönholm committed a06aac9

Refactored codebase
Replaced static wildcard imports with dynamic, explicit imports
Added boolean/enum autodetection

Comments (0)

Files changed (5)

+Version history
+===============
+
+1.1.0
+-----
+
+* Revamped the API
+
+* Fixed missing class name prefix in primary/secondary joins in relationships
+
+* Instead of wildcard imports, generate explicit imports dynamically (fixes #1)
+
+* Automatically detect Boolean columns based on CheckConstraints
+
+* Skip redundant CheckConstraints for Enum and Boolean columns
+
+
+1.0.0
+-----
+
+* Initial release
     name='sqlacodegen',
     description='Automatic model code generator for SQLAlchemy',
     long_description=readme,
-    version='1.0.0.post2',
+    version='1.1.0.pre1',
     author='Alex Gronholm',
     author_email='sqlacodegen@nextday.fi',
     url='http://pypi.python.org/pypi/sqlacodegen/',

sqlacodegen/codegen.py

 from collections import defaultdict
 from keyword import iskeyword
 import inspect
+import sys
+import re
 
-from sqlalchemy.types import Enum
-from sqlalchemy.schema import ForeignKeyConstraint, PrimaryKeyConstraint, CheckConstraint, UniqueConstraint
+from sqlalchemy import (Enum, ForeignKeyConstraint, PrimaryKeyConstraint, CheckConstraint, UniqueConstraint, Table,
+                        Column)
+from sqlalchemy.util import OrderedDict
+from sqlalchemy.types import Boolean, String
+import sqlalchemy
 
 try:
     from sqlalchemy.sql.expression import TextClause
     from sqlalchemy.sql.expression import _TextClause as TextClause
 
 
-DEFAULT_HEADER = """\
-# coding: utf-8
-from sqlalchemy import *
-from sqlalchemy.orm import *
-from sqlalchemy.ext.declarative import declarative_base
+_re_boolean_check_constraint = re.compile(r"(?:(?:.*?)\.)?(.*?) IN \(0, 1\)")
+_re_enum_check_constraint = re.compile(r"(?:(?:.*?)\.)?(.*?) IN \((.+)\)")
+_re_enum_item = re.compile(r"'(.*?)(?<!\\)'")
 
 
-Base = declarative_base()
+def _singular(plural):
+    """A feeble attempt at converting plural English nouns into singular form."""
+    if plural.endswith('ies'):
+        return plural[:-3] + 'y'
+    if plural.endswith('s') and not plural.endswith('ss'):
+        return plural[:-1]
+    return plural
 
 
-"""
-DEFAULT_FOOTER = ""
+def _tablename_to_classname(tablename):
+    return _singular(''.join(part.capitalize() for part in tablename.split('_')))
 
 
-def get_compiled_expression(statement):
+def _get_compiled_expression(statement):
     """Returns the statement in a form where any placeholders have been filled in."""
     if isinstance(statement, TextClause):
         return statement.text
     return compiler.process(statement)
 
 
-def singular(plural):
-    """A feeble attempt at converting plural English nouns into singular form."""
-    if plural.endswith('ies'):
-        return plural[:-3] + 'y'
-    if plural.endswith('s') and not plural.endswith('ss'):
-        return plural[:-1]
-    return plural
+def _get_constraint_sort_key(constraint):
+    if isinstance(constraint, CheckConstraint):
+        return 'C{0}'.format(constraint.sqltext)
+    return constraint.__class__.__name__[0] + repr(constraint.columns)
 
 
-def get_typename(type_):
-    """Returns the most reasonable column type name to use (ie. String instead of VARCHAR)."""
-    cls = type_.__class__
+def _get_common_fk_constraints(table1, table2):
+    """Returns a set of foreign key constraints the two tables have against each other."""
+    c1 = set(c for c in table1.constraints if isinstance(c, ForeignKeyConstraint) and
+             c.elements[0].column.table == table2)
+    c2 = set(c for c in table2.constraints if isinstance(c, ForeignKeyConstraint) and
+             c.elements[0].column.table == table1)
+    return c1.union(c2)
+
+
+def _render_column_type(coltype):
+    # Figure out the most reasonable column type name to use (ie. String instead of VARCHAR)
+    cls = coltype.__class__
     typename = cls.__class__.__name__
     for supercls in cls.__mro__:
         if hasattr(supercls, '__visit_name__'):
         if supercls.__name__ != supercls.__name__.upper():
             break
 
-    return typename
-
-
-def get_constraint_sort_key(constraint):
-    if isinstance(constraint, CheckConstraint):
-        return 'C{0}'.format(constraint.sqltext)
-    return constraint.__class__.__name__[0] + repr(constraint.columns)
-
-
-def get_common_fk_constraints(table1, table2):
-    """Returns a set of foreign key constraints the two tables have against each other."""
-    c1 = set(c for c in table1.constraints if isinstance(c, ForeignKeyConstraint) and
-             c.elements[0].column.table == table2)
-    c2 = set(c for c in table2.constraints if isinstance(c, ForeignKeyConstraint) and
-             c.elements[0].column.table == table1)
-    return c1.union(c2)
-
-
-def tablename_to_classname(tablename):
-    return singular(''.join(part.capitalize() for part in tablename.split('_')))
-
-
-def classname_to_tablename(classname):
-    tablename = classname[0].lower()
-    for c in classname[1:]:
-        if c.isupper():
-            c = '_' + c.lower()
-        tablename += c
-
-    return tablename
-
-
-def generate_type_code(type_):
-    text = get_typename(type_)
+    text = typename
     args = []
 
-    if isinstance(type_, Enum):
-        args.extend(repr(arg) for arg in type_.enums)
-        if type_.name is not None:
-            args.append('name={0!r}'.format(type_.name))
+    if isinstance(coltype, Enum):
+        args.extend(repr(arg) for arg in coltype.enums)
+        if coltype.name is not None:
+            args.append('name={0!r}'.format(coltype.name))
     else:
         # All other types
-        argspec = inspect.getargspec(type_.__class__.__init__)
+        argspec = inspect.getargspec(coltype.__class__.__init__)
         defaults = dict(zip(argspec.args[-len(argspec.defaults or ()):], argspec.defaults or ()))
         missing = object()
         use_kwargs = False
         for attr in argspec.args[1:]:
-            value = getattr(type_, attr, missing)
+            value = getattr(coltype, attr, missing)
             default = defaults.get(attr, missing)
             if value is missing or value == default:
                 use_kwargs = True
     return text
 
 
-def generate_column_code(column, show_name):
-    kwarg = []
-    is_sole_pk = column.primary_key and len(column.table.primary_key) == 1
-    dedicated_fks = [c for c in column.foreign_keys if len(c.constraint.columns) == 1]
-    is_unique = any(isinstance(c, UniqueConstraint) and set(c.columns) == set([column])
-                    for c in column.table.constraints)
-    has_index = any(set(i.columns) == set([column]) for i in column.table.indexes)
+class ImportCollector(OrderedDict):
+    def add_import(self, obj):
+        type_ = type(obj) if not isinstance(obj, type) else obj
+        pkgname = 'sqlalchemy' if type_.__name__ in sqlalchemy.__all__ else type_.__module__
+        self.add_literal_import(pkgname, type_.__name__)
 
-    if column.key != column.name:
-        kwarg.append('key')
-    if column.primary_key:
-        kwarg.append('primary_key')
-    if not column.nullable and not is_sole_pk:
-        kwarg.append('nullable')
-    if is_unique:
-        column.unique = True
-        kwarg.append('unique')
-    elif has_index:
-        column.index = True
-        kwarg.append('index')
-    if column.server_default and not is_sole_pk:
-        column.server_default = get_compiled_expression(column.server_default.arg)
-        kwarg.append('server_default')
+    def add_literal_import(self, pkgname, name):
+        names = self.setdefault(pkgname, set())
+        names.add(name)
 
-    return 'Column({0})'.format(', '.join(
-        ([repr(column.name)] if show_name else []) +
-        ([generate_type_code(column.type)] if not dedicated_fks else []) +
-        [repr(x) for x in dedicated_fks] +
-        [repr(x) for x in column.constraints] +
-        ['{0}={1}'.format(k, repr(getattr(column, k))) for k in kwarg]))
+    def render(self):
+        return '\n'.join('from {0} import {1}'.format(package, ', '.join(sorted(names)))
+                         for package, names in self.items())
 
 
-def generate_constraint_reprs(constraints):
-    for constraint in sorted(constraints, key=get_constraint_sort_key):
+class Model(object):
+    def __init__(self, table):
+        super(Model, self).__init__()
+        self.table = table
+
+    def add_imports(self, collector):
+        if self.table.columns:
+            collector.add_import(Column)
+
+        for column in self.table.columns:
+            collector.add_import(column.type)
+
+        for constraint in sorted(self.table.constraints, key=_get_constraint_sort_key):
+            if isinstance(constraint, ForeignKeyConstraint):
+                if len(constraint.columns) > 1:
+                    collector.add_literal_import('sqlalchemy', 'ForeignKeyConstraint')
+                else:
+                    collector.add_literal_import('sqlalchemy', 'ForeignKey')
+            elif isinstance(constraint, UniqueConstraint):
+                if len(constraint.columns) > 1:
+                    collector.add_literal_import('sqlalchemy', 'UniqueConstraint')
+            elif not isinstance(constraint, PrimaryKeyConstraint):
+                collector.add_import(constraint)
+
+        for index in self.table.indexes:
+            if len(index.columns) > 1:
+                collector.add_import(index)
+
+    @staticmethod
+    def _render_column(column, show_name):
+        kwarg = []
+        is_sole_pk = column.primary_key and len(column.table.primary_key) == 1
+        dedicated_fks = [c for c in column.foreign_keys if len(c.constraint.columns) == 1]
+        is_unique = any(isinstance(c, UniqueConstraint) and set(c.columns) == set([column])
+                        for c in column.table.constraints)
+        has_index = any(set(i.columns) == set([column]) for i in column.table.indexes)
+
+        if column.key != column.name:
+            kwarg.append('key')
+        if column.primary_key:
+            kwarg.append('primary_key')
+        if not column.nullable and not is_sole_pk:
+            kwarg.append('nullable')
+        if is_unique:
+            column.unique = True
+            kwarg.append('unique')
+        elif has_index:
+            column.index = True
+            kwarg.append('index')
+        if column.server_default and not is_sole_pk:
+            column.server_default = _get_compiled_expression(column.server_default.arg)
+            kwarg.append('server_default')
+
+        return 'Column({0})'.format(', '.join(
+            ([repr(column.name)] if show_name else []) +
+            ([_render_column_type(column.type)] if not dedicated_fks else []) +
+            [repr(x) for x in dedicated_fks] +
+            [repr(x) for x in column.constraints] +
+            ['{0}={1}'.format(k, repr(getattr(column, k))) for k in kwarg]))
+
+    @staticmethod
+    def _render_constraint(constraint):
         if isinstance(constraint, ForeignKeyConstraint):
-            if len(constraint.columns) > 1:
-                local_columns = constraint.columns
-                remote_columns = ['{0}.{1}'.format(fk.column.table.name, fk.column.name) for fk in constraint.elements]
-                yield 'ForeignKeyConstraint({0!r}, {1!r})'.format(local_columns, remote_columns)
+            local_columns = constraint.columns
+            remote_columns = ['{0}.{1}'.format(fk.column.table.name, fk.column.name)
+                              for fk in constraint.elements]
+            return 'ForeignKeyConstraint({0!r}, {1!r})'.format(local_columns, remote_columns)
         elif isinstance(constraint, CheckConstraint):
-            yield 'CheckConstraint({0!r})'.format(get_compiled_expression(constraint.sqltext))
+            return 'CheckConstraint({0!r})'.format(_get_compiled_expression(constraint.sqltext))
         elif isinstance(constraint, UniqueConstraint):
-            if len(constraint.columns) > 1:
-                columns = [repr(col.name) for col in constraint.columns]
-                yield 'UniqueConstraint({0})'.format(', '.join(columns))
+            columns = [repr(col.name) for col in constraint.columns]
+            return 'UniqueConstraint({0})'.format(', '.join(columns))
 
+    @staticmethod
+    def _render_index(index):
+        columns = [repr(col.name) for col in index.columns]
+        return 'Index({0!r}, {1})'.format(index.name, ', '.join(columns))
 
-def generate_index_reprs(indexes):
-    for index in indexes:
-        if len(index.columns) > 1:
-            columns = [repr(col.name) for col in index.columns]
-            yield 'Index({0!r}, {1})'.format(index.name, ', '.join(columns))
 
+class ModelTable(Model):
+    def add_imports(self, collector):
+        super(ModelTable, self).add_imports(collector)
+        collector.add_import(Table)
 
-def generate_table(table):
-    elements = [generate_column_code(column, True) for column in table.c]
-    elements.extend(generate_constraint_reprs(table.constraints))
-    elements.extend(generate_index_reprs(table.indexes))
-    return """\
-t_{0} = Table(
-    {0!r}, Base.metadata,
-    {1}
-)
-""".format(table.name, ',\n    '.join(elements))
+    def render(self):
+        text = 't_{0} = Table(\n    {0!r}, metadata,\n'.format(self.table.name)
 
+        for column in self.table.columns:
+            text += '    {0},\n'.format(self._render_column(column, True))
 
-def generate_relationship_name(colname, remote_classname, used_names):
-    name = base = classname_to_tablename(remote_classname) if not colname.endswith('_id') else colname[:-3]
-    iteration = 0
-    while name in used_names:
-        iteration += 1
-        name = base + str(iteration)
+        for constraint in sorted(self.table.constraints, key=_get_constraint_sort_key):
+            if isinstance(constraint, PrimaryKeyConstraint):
+                continue
+            if isinstance(constraint, (ForeignKeyConstraint, UniqueConstraint)) and len(constraint.columns) == 1:
+                continue
+            text += '    {0},\n'.format(self._render_constraint(constraint))
 
-    used_names.add(name)
-    return name
+        for index in self.table.indexes:
+            if len(index.columns) > 1:
+                text += '    {0},\n'.format(self._render_index(index))
 
+        return text.rstrip('\n,') + '\n)'
 
-def generate_relationship(classname, used_names, fk_constraint=None, link_table=None):
-    remote_side = uselist = secondary = primaryjoin = secondaryjoin = None
-    if link_table is not None:
-        # Many-to-Many
-        secondary = 'secondary=' + repr(link_table.name)
-        fk_constraints = [c for c in link_table.constraints if isinstance(c, ForeignKeyConstraint)]
-        fk_constraints.sort(key=get_constraint_sort_key)
-        remote_tablename = fk_constraints[1].elements[0].column.table.name
-        remote_classname = tablename_to_classname(remote_tablename)
-        relationship_name = generate_relationship_name(fk_constraints[1].columns[0], remote_classname, used_names) + 's'
 
-        # Handle self referential relationships
-        if classname == remote_classname:
-            pri_pairs = zip(fk_constraints[0].columns, fk_constraints[0].elements)
-            sec_pairs = zip(fk_constraints[1].columns, fk_constraints[1].elements)
-            pri_joins = ['{0}.{1} == {2}.c.{3}'.format(classname, elem.column.name, link_table.name, col)
-                         for col, elem in pri_pairs]
-            sec_joins = ['{0}.{1} == {2}.c.{3}'.format(classname, elem.column.name, link_table.name, col)
-                         for col, elem in sec_pairs]
-            primaryjoin = 'primaryjoin=' + (
-                repr('and_({0})'.format(', '.join(pri_joins))) if len(pri_joins) > 1 else repr(pri_joins[0]))
-            secondaryjoin = 'secondaryjoin=' + (
-                repr('and_({0})'.format(', '.join(sec_joins))) if len(sec_joins) > 1 else repr(sec_joins[0]))
-    else:
-        # One-to-Many or One-to-One
-        remote_classname = tablename_to_classname(fk_constraint.elements[0].column.table.name)
-        relationship_name = generate_relationship_name(fk_constraint.columns[0], remote_classname, used_names)
+class ModelClass(Model):
+    def __init__(self, table, association_tables):
+        super(ModelClass, self).__init__(table)
+        self.name = _tablename_to_classname(table.name)
+        self.children = []
+        self.attributes = OrderedDict()
+
+        # Assign attribute names for columns
+        for column in table.columns:
+            attrname = column.name + '_' if iskeyword(column.name) else column.name
+            self.attributes[attrname] = column
+
+        # Add many-to-one relationships
+        for constraint in sorted(table.constraints, key=_get_constraint_sort_key):
+            if isinstance(constraint, ForeignKeyConstraint):
+                target_cls = _tablename_to_classname(constraint.elements[0].column.table.name)
+                self._add_relationship(ManyToOneRelationship(self.name, target_cls, constraint))
+
+        # Add many-to-many relationships
+        for association_table in association_tables:
+            fk_constraints = [c for c in association_table.constraints if isinstance(c, ForeignKeyConstraint)]
+            fk_constraints.sort(key=_get_constraint_sort_key)
+            target_cls = _tablename_to_classname(fk_constraints[1].elements[0].column.table.name)
+            self._add_relationship(ManyToManyRelationship(self.name, target_cls, association_table))
+
+    def _add_relationship(self, relationship):
+        for attrname in relationship.suggested_names:
+            if attrname not in self.attributes and not iskeyword(attrname):
+                self.attributes[attrname] = relationship
+                break
+
+    def add_imports(self, collector):
+        super(ModelClass, self).add_imports(collector)
+
+        if any(isinstance(value, Relationship) for value in self.attributes.values()):
+            collector.add_literal_import('sqlalchemy.orm', 'relationship')
+
+        for child in self.children:
+            child.add_imports(collector)
+
+    def render(self, parentname='Base'):
+        text = 'class {0}({1}):\n'.format(self.name, parentname)
+        text += '    __tablename__ = {0!r}\n'.format(self.table.name)
+
+        table_args = []
+        for constraint in sorted(self.table.constraints, key=_get_constraint_sort_key):
+            if isinstance(constraint, PrimaryKeyConstraint):
+                continue
+            if isinstance(constraint, (ForeignKeyConstraint, UniqueConstraint)) and len(constraint.columns) == 1:
+                continue
+            table_args.append(self._render_constraint(constraint))
+        for index in self.table.indexes:
+            if len(index.columns) > 1:
+                table_args.append(self._render_index(index))
+        if table_args:
+            if len(table_args) == 1:
+                table_args[0] += ','
+            text += '    __table_args__ = (\n        {0}\n    )\n'.format(',\n        '.join(table_args))
+
+        text += '\n'
+        for attr, column in self.attributes.items():
+            if isinstance(column, Column):
+                show_name = attr != column.name
+                text += '    {0} = {1}\n'.format(attr, self._render_column(column, show_name))
+
+        if any(isinstance(value, Relationship) for value in self.attributes.values()):
+            text += '\n'
+        for attr, relationship in self.attributes.items():
+            if isinstance(relationship, Relationship):
+                text += '    {0} = {1}\n'.format(attr, relationship.render())
+
+        for child_class in self.children:
+            text += '\n' + child_class.render(self.classname)
+
+        return text
+
+
+class Relationship(object):
+    def __init__(self, source_cls, target_cls):
+        super(Relationship, self).__init__()
+        self.source_cls = source_cls
+        self.target_cls = target_cls
+        self.kwargs = OrderedDict()
+
+    @property
+    def suggested_names(self):
+        yield self.preferred_name if not iskeyword(self.preferred_name) else self.preferred_name + '_'
+
+        iteration = 0
+        while True:
+            iteration += 1
+            yield self.preferred_name + str(iteration)
+
+    def render(self):
+        text = 'relationship('
+        args = [repr(self.target_cls)]
+
+        if 'secondaryjoin' in self.kwargs:
+            text += '\n        '
+            delimiter, end = ',\n        ', '\n    )'
+        else:
+            delimiter, end = ', ', ')'
+
+        args.extend([key + '=' + value for key, value in self.kwargs.items()])
+        return text + delimiter.join(args) + end
+
+
+class ManyToOneRelationship(Relationship):
+    def __init__(self, source_cls, target_cls, constraint):
+        super(ManyToOneRelationship, self).__init__(source_cls, target_cls)
+
+        colname = constraint.columns[0]
+        tablename = constraint.elements[0].column.table.name
+        self.preferred_name = _singular(tablename) if not colname.endswith('_id') else colname[:-3]
 
         # Add uselist=False to One-to-One relationships
         if any(isinstance(c, (PrimaryKeyConstraint, UniqueConstraint)) and
-               set(col.name for col in c.columns) == set(fk_constraint.columns)
-               for c in fk_constraint.table.constraints):
-            uselist = 'uselist=False'
+               set(col.name for col in c.columns) == set(constraint.columns)
+               for c in constraint.table.constraints):
+            self.kwargs['uselist'] = 'False'
 
         # Handle self referential relationships
-        if classname == remote_classname:
-            pk_col_names = [col.name for col in fk_constraint.table.primary_key]
-            remote_side = 'remote_side=[{0}]'.format(', '.join(pk_col_names))
+        if source_cls == target_cls:
+            self.preferred_name = 'parent' if not colname.endswith('_id') else colname[:-3]
+            pk_col_names = [col.name for col in constraint.table.primary_key]
+            self.kwargs['remote_side'] = '[{0}]'.format(', '.join(pk_col_names))
 
         # If the two tables share more than one foreign key constraint,
         # SQLAlchemy needs an explicit primaryjoin to figure out which column(s) to join with
-        common_fk_constraints = get_common_fk_constraints(fk_constraint.table, fk_constraint.elements[0].column.table)
+        common_fk_constraints = _get_common_fk_constraints(constraint.table, constraint.elements[0].column.table)
         if len(common_fk_constraints) > 1:
-            primaryjoin = "primaryjoin='{0}.{1} == {2}.{3}'".format(classname, fk_constraint.columns[0],
-                                                                    remote_classname,
-                                                                    fk_constraint.elements[0].column.name)
+            self.kwargs['primaryjoin'] = "'{0}.{1} == {2}.{3}'".format(source_cls, constraint.columns[0], target_cls,
+                                                                       constraint.elements[0].column.name)
 
-    args = [arg for arg in (repr(remote_classname), remote_side, uselist, secondary, primaryjoin, secondaryjoin) if arg]
-    if secondaryjoin:
-        return '{0} = relationship(\n        {1}\n    )'.format(relationship_name, ',\n        '.join(args))
-    else:
-        return '{0} = relationship({1})'.format(relationship_name, ', '.join(args))
 
+class ManyToManyRelationship(Relationship):
+    def __init__(self, source_cls, target_cls, assocation_table):
+        super(ManyToManyRelationship, self).__init__(source_cls, target_cls)
 
-def generate_class(table, links=()):
-    used_names = set()
-    classname = tablename_to_classname(table.name)
-    text = 'class {0}(Base):\n    __tablename__ = {1!r}\n'.format(classname, table.name)
+        self.kwargs['secondary'] = repr(assocation_table.name)
+        constraints = [c for c in assocation_table.constraints if isinstance(c, ForeignKeyConstraint)]
+        constraints.sort(key=_get_constraint_sort_key)
+        colname = constraints[1].columns[0]
+        tablename = constraints[1].elements[0].column.table.name
+        self.preferred_name = tablename if not colname.endswith('_id') else colname[:-3] + 's'
 
-    constraints = sorted(table.constraints, key=get_constraint_sort_key)
-    table_args = list(generate_constraint_reprs(constraints)) + list(generate_index_reprs(table.indexes))
-    if table_args:
-        if len(table_args) == 1:
-            table_args[0] += ','  # Required for this to be a tuple
-        text += '    __table_args__ = (\n        {0}\n    )\n'.format(',\n        '.join(table_args))
-    text += '\n'
+        # Handle self referential relationships
+        if source_cls == target_cls:
+            self.preferred_name = 'parents' if not colname.endswith('_id') else colname[:-3] + 's'
+            pri_pairs = zip(constraints[0].columns, constraints[0].elements)
+            sec_pairs = zip(constraints[1].columns, constraints[1].elements)
+            pri_joins = ['{0}.{1} == {2}.c.{3}'.format(source_cls, elem.column.name, assocation_table.name, col)
+                         for col, elem in pri_pairs]
+            sec_joins = ['{0}.{1} == {2}.c.{3}'.format(target_cls, elem.column.name, assocation_table.name, col)
+                         for col, elem in sec_pairs]
+            self.kwargs['primaryjoin'] = (
+                repr('and_({0})'.format(', '.join(pri_joins))) if len(pri_joins) > 1 else repr(pri_joins[0]))
+            self.kwargs['secondaryjoin'] = (
+                repr('and_({0})'.format(', '.join(sec_joins))) if len(sec_joins) > 1 else repr(sec_joins[0]))
 
-    # Generate columns
-    for column in table.c:
-        attrname = column.name + '_' if iskeyword(column.name) else column.name
-        used_names.add(attrname)
-        col_repr = generate_column_code(column, attrname != column.name)
-        text += '    {0} = {1}\n'.format(attrname, col_repr)
 
-    if links or table.foreign_keys:
-        text += '\n'
+class CodeGenerator(object):
+    header = '# coding: utf-8'
+    footer = ''
 
-    # Generate many-to-one relationships
-    for constraint in constraints:
-        if isinstance(constraint, ForeignKeyConstraint):
-            relationship = generate_relationship(classname, used_names, fk_constraint=constraint)
-            text += '    {0}\n'.format(relationship)
+    def __init__(self, metadata, noindexes=False, noconstraints=False):
+        super(CodeGenerator, self).__init__()
+        self.models = []
+        self.collector = ImportCollector()
 
-    # Generate many-to-many relationships
-    for link_table in links:
-        relationship = generate_relationship(classname, used_names, link_table=link_table)
-        text += '    {0}\n'.format(relationship)
+        # Pick association tables from the metadata into their own set, don't process them normally
+        links = defaultdict(lambda: [])
+        association_tables = set()
+        for table in metadata.tables.values():
+            # Link tables have exactly two foreign key constraints and all columns are involved in them
+            fk_constraints = [constr for constr in table.constraints if isinstance(constr, ForeignKeyConstraint)]
+            if (len(fk_constraints) == 2 and all(col.foreign_keys for col in table.columns)):
+                association_tables.add(table.name)
+                tablename = sorted(fk_constraints, key=_get_constraint_sort_key)[0].elements[0].column.table.name
+                links[tablename].append(table)
 
-    return text
+        # Iterate through the tables and create model classes when possible
+        for table in sorted(metadata.tables.values(), key=lambda t: t.name):
+            # Support for Alembic and sqlalchemy-migrate -- never expose the schema version tables
+            if table.name in ('alembic_version', 'migrate_version'):
+                continue
 
+            if noindexes:
+                table.indexes.clear()
 
-def generate_declarative_models(metadata, noindexes=False, noconstraints=False):
-    links = defaultdict(lambda: [])
-    link_tables = set()
-    for table in metadata.tables.values():
-        # Link tables have exactly two foreign key constraints and all columns are involved in them
-        fk_constraints = [constr for constr in table.constraints if isinstance(constr, ForeignKeyConstraint)]
-        if (len(fk_constraints) == 2 and all(col.foreign_keys for col in table.columns)):
-            link_tables.add(table.name)
-            tablename = sorted(fk_constraints, key=get_constraint_sort_key)[0].elements[0].column.table.name
-            links[tablename].append(table)
+            if noconstraints:
+                table.constraints = set([table.primary_key])
+                table.foreign_keys.clear()
+                for col in table.columns:
+                    col.foreign_keys.clear()
+            else:
+                # Detect check constraints for boolean and enum columns
+                for constraint in table.constraints.copy():
+                    if isinstance(constraint, CheckConstraint):
+                        sqltext = _get_compiled_expression(constraint.sqltext)
+                        match = _re_boolean_check_constraint.match(sqltext)
+                        if match:
+                            colname = match.group(1)
+                            table.constraints.remove(constraint)
+                            table.c[colname].type = Boolean()
+                            continue
 
-    for table in sorted(metadata.tables.values(), key=lambda t: t.name):
-        # Support for Alembic and sqlalchemy-migrate -- never expose the schema version tables
-        if table.name in ('alembic_version', 'migrate_version'):
-            continue
+                        match = _re_enum_check_constraint.match(sqltext)
+                        if match:
+                            colname = match.group(1)
+                            items = match.group(2)
+                            if isinstance(table.c[colname].type, String):
+                                table.constraints.remove(constraint)
+                                if not isinstance(table.c[colname].type, Enum):
+                                    options = _re_enum_item.findall(items)
+                                    table.c[colname].type = Enum(*options, native_enum=False)
+                                continue
 
-        if noindexes:
-            table.indexes.clear()
+            # Only form model classes for tables that have a primary key and are not association tables
+            if not table.primary_key or table.name in association_tables:
+                model = ModelTable(table)
+            else:
+                model = ModelClass(table, links[table.name])
 
-        if noconstraints:
-            table.constraints = set([table.primary_key])
-            table.foreign_keys.clear()
-            for col in table.columns:
-                col.foreign_keys.clear()
+            self.models.append(model)
+            model.add_imports(self.collector)
 
-        if not table.primary_key or table.name in link_tables:
-            yield generate_table(table)
+        if not any(isinstance(model, ModelClass) for model in self.models):
+            self.collector.add_literal_import('sqlalchemy', 'MetaData')
         else:
-            yield generate_class(table, links[table.name])
+            self.collector.add_literal_import('sqlalchemy.ext.declarative', 'declarative_base')
 
+    def render(self, outfile=sys.stdout):
+        print(self.header, file=outfile)
 
-def generate_model_code(metadata, noindexes, noconstraints, header=DEFAULT_HEADER, footer=DEFAULT_FOOTER):
-    models = generate_declarative_models(metadata, noindexes, noconstraints)
-    return header + '\n\n'.join(models).rstrip() + footer
+        # Render the collected imports
+        print(self.collector.render() + '\n\n', file=outfile)
+
+        if any(isinstance(model, ModelClass) for model in self.models):
+            print('Base = declarative_base()\nmetadata = Base.metadata', file=outfile)
+        else:
+            print('metadata = MetaData()', file=outfile)
+
+        # Render the model tables and classes
+        for model in self.models:
+            print('\n\n' + model.render().rstrip('\n'), file=outfile)
+
+        if self.footer:
+            print(self.footer, file=outfile)

sqlacodegen/main.py

 from sqlalchemy.engine import create_engine
 from sqlalchemy.schema import MetaData
 
-from sqlacodegen.codegen import generate_model_code
+from sqlacodegen.codegen import CodeGenerator
 
 
 def main():
     metadata = MetaData(engine)
     tables = args.tables.split(',') if args.tables else None
     metadata.reflect(engine, args.schema, not args.noviews, tables)
-    print(generate_model_code(metadata, args.noindexes, args.noconstraints), file=args.outfile)
+    generator = CodeGenerator(metadata, args.noindexes, args.noconstraints)
+    generator.render(args.outfile)

test/test_codegen.py

 from __future__ import unicode_literals, division, print_function, absolute_import
+from io import StringIO
 import sys  # @UnusedImport
 import re
 
 from nose.tools import eq_
 from sqlalchemy import *
 
-from sqlacodegen.codegen import (generate_declarative_models, generate_type_code, singular, generate_relationship_name,
-                                 generate_table, generate_class)
+from sqlacodegen.codegen import CodeGenerator, _singular
 
 
 if sys.version_info < (3,):
-    unicode_re = re.compile(r"u'([^']*)'")
+    unicode_re = re.compile(r"u('|\")(.*?)(?<!\\)\1")
 
     def remove_unicode_prefixes(text):
-        return unicode_re.sub(r"'\1'", text)
+        return unicode_re.sub(r"\1\2\1", text)
 else:
     remove_unicode_prefixes = lambda text: text
 
 
 def test_singular_ies():
-    eq_(singular('bunnies'), 'bunny')
+    eq_(_singular('bunnies'), 'bunny')
 
 
 def test_singular_ss():
-    eq_(singular('address'), 'address')
+    eq_(_singular('address'), 'address')
 
 
-def test_generate_relationship_id():
-    eq_(generate_relationship_name('item_id', 'Item', set()), 'item')
+class TestCodeGenerator(object):
+    def generate_code(self, metadata, **kwargs):
+        codegen = CodeGenerator(metadata, **kwargs)
+        sio = StringIO()
+        codegen.render(sio)
+        return remove_unicode_prefixes(sio.getvalue())
 
+    def test_fancy_coltypes(self):
+        testmeta = MetaData(create_engine('sqlite:///'))
+        Table(
+            'simple_items', testmeta,
+            Column('enum', Enum('A', 'B', name='blah')),
+            Column('bool', Boolean),
+            Column('number', Numeric(10, asdecimal=False)),
+        )
 
-def test_generate_relationship_multi():
-    used_names = set()
-    eq_(generate_relationship_name('fk1', 'RemoteClass', used_names), 'remote_class')
-    eq_(generate_relationship_name('fk2', 'RemoteClass', used_names), 'remote_class1')
+        eq_(self.generate_code(testmeta), """\
+# coding: utf-8
+from sqlalchemy import Boolean, Column, Enum, MetaData, Numeric, Table
 
 
-def test_typecode_plain():
-    eq_(generate_type_code(Integer()), 'Integer')
+metadata = MetaData()
 
 
-def test_typecode_arg():
-    eq_(generate_type_code(String(20)), 'String(20)')
+t_simple_items = Table(
+    'simple_items', metadata,
+    Column('enum', Enum('A', 'B', name='blah')),
+    Column('bool', Boolean),
+    Column('number', Numeric(10, asdecimal=False))
+)
+""")
 
+    def test_boolean_detection(self):
+        testmeta = MetaData(create_engine('sqlite:///'))
+        Table(
+            'simple_items', testmeta,
+            Column('bool', SmallInteger),
+            CheckConstraint('simple_items.bool IN (0, 1)')
+        )
 
-def test_typecode_kwarg():
-    typecode = generate_type_code(Numeric(10, asdecimal=False))
-    eq_(typecode, 'Numeric(10, asdecimal=False)')
+        eq_(self.generate_code(testmeta), """\
+# coding: utf-8
+from sqlalchemy import Boolean, Column, MetaData, Table
 
 
-def test_typecode_enum():
-    typecode = generate_type_code(Enum('A', 'B', name='blah'))
-    eq_(remove_unicode_prefixes(typecode), "Enum('A', 'B', name='blah')")
+metadata = MetaData()
 
 
-def test_constraints_table():
-    testmeta = MetaData()
-    simple_items = Table(
-        'simple_items', testmeta,
-        Column('id', Integer, primary_key=True),
-        Column('number', Integer),
-        CheckConstraint('number > 0'),
-        UniqueConstraint('id', 'number')
-    )
+t_simple_items = Table(
+    'simple_items', metadata,
+    Column('bool', Boolean)
+)
+""")
 
-    table_def = remove_unicode_prefixes(generate_table(simple_items))
-    eq_(table_def, """\
+    def test_enum_detection(self):
+        testmeta = MetaData(create_engine('sqlite:///'))
+        Table(
+            'simple_items', testmeta,
+            Column('enum', String),
+            CheckConstraint(r"simple_items.enum IN ('A', '\'B', 'C')")
+        )
+
+        eq_(self.generate_code(testmeta), """\
+# coding: utf-8
+from sqlalchemy import Column, Enum, MetaData, Table
+
+
+metadata = MetaData()
+
+
 t_simple_items = Table(
-    'simple_items', Base.metadata,
-    Column('id', Integer, primary_key=True),
+    'simple_items', metadata,
+    Column('enum', Enum('A', "\\\\'B", 'C'))
+)
+""")
+
+    def test_constraints_table(self):
+        testmeta = MetaData()
+        Table(
+            'simple_items', testmeta,
+            Column('id', Integer),
+            Column('number', Integer),
+            CheckConstraint('number > 0'),
+            UniqueConstraint('id', 'number')
+        )
+
+        eq_(self.generate_code(testmeta), """\
+# coding: utf-8
+from sqlalchemy import CheckConstraint, Column, Integer, MetaData, Table, UniqueConstraint
+
+
+metadata = MetaData()
+
+
+t_simple_items = Table(
+    'simple_items', metadata,
+    Column('id', Integer),
     Column('number', Integer),
     CheckConstraint('number > 0'),
     UniqueConstraint('id', 'number')
 )
 """)
 
+    def test_constraints_class(self):
+        testmeta = MetaData()
+        Table(
+            'simple_items', testmeta,
+            Column('id', Integer, primary_key=True),
+            Column('number', Integer),
+            CheckConstraint('number > 0'),
+            UniqueConstraint('id', 'number')
+        )
 
-def test_constraints_class():
-    testmeta = MetaData()
-    simple_items = Table(
-        'simple_items', testmeta,
-        Column('id', Integer, primary_key=True),
-        Column('number', Integer),
-        CheckConstraint('number > 0'),
-        UniqueConstraint('id', 'number')
-    )
+        eq_(self.generate_code(testmeta), """\
+# coding: utf-8
+from sqlalchemy import CheckConstraint, Column, Integer, UniqueConstraint
+from sqlalchemy.ext.declarative import declarative_base
 
-    table_def = remove_unicode_prefixes(generate_class(simple_items))
-    eq_(table_def, """\
+
+Base = declarative_base()
+metadata = Base.metadata
+
+
 class SimpleItem(Base):
     __tablename__ = 'simple_items'
     __table_args__ = (
     number = Column(Integer)
 """)
 
+    def test_noindexes_table(self):
+        testmeta = MetaData()
+        simple_items = Table(
+            'simple_items', testmeta,
+            Column('number', Integer),
+            CheckConstraint('number > 2')
+        )
+        simple_items.indexes.add(Index('idx_number', simple_items.c.number))
 
-def test_noindexes_table():
-    testmeta = MetaData()
-    simple_items = Table(
-        'simple_items', testmeta,
-        Column('number', Integer),
-        CheckConstraint('number > 2')
-    )
-    simple_items.indexes.add(Index('idx_number', simple_items.c.number))
+        eq_(self.generate_code(testmeta, noindexes=True), """\
+# coding: utf-8
+from sqlalchemy import CheckConstraint, Column, Integer, MetaData, Table
 
-    table_def = remove_unicode_prefixes(next(generate_declarative_models(testmeta, noindexes=True)))
-    eq_(table_def, """\
+
+metadata = MetaData()
+
+
 t_simple_items = Table(
-    'simple_items', Base.metadata,
+    'simple_items', metadata,
     Column('number', Integer),
     CheckConstraint('number > 2')
 )
 """)
 
+    def test_noconstraints_table(self):
+        testmeta = MetaData()
+        simple_items = Table(
+            'simple_items', testmeta,
+            Column('number', Integer),
+            CheckConstraint('number > 2')
+        )
+        simple_items.indexes.add(Index('idx_number', simple_items.c.number))
 
-def test_noconstraints_table():
-    testmeta = MetaData()
-    simple_items = Table(
-        'simple_items', testmeta,
-        Column('number', Integer),
-        CheckConstraint('number > 2')
-    )
-    simple_items.indexes.add(Index('idx_number', simple_items.c.number))
+        eq_(self.generate_code(testmeta, noconstraints=True), """\
+# coding: utf-8
+from sqlalchemy import Column, Integer, MetaData, Table
 
-    table_def = remove_unicode_prefixes(next(generate_declarative_models(testmeta, noconstraints=True)))
-    eq_(table_def, """\
+
+metadata = MetaData()
+
+
 t_simple_items = Table(
-    'simple_items', Base.metadata,
+    'simple_items', metadata,
     Column('number', Integer, index=True)
 )
 """)
 
+    def test_indexes_table(self):
+        testmeta = MetaData()
+        simple_items = Table(
+            'simple_items', testmeta,
+            Column('id', Integer),
+            Column('number', Integer),
+            Column('text', String)
+        )
+        simple_items.indexes.add(Index('idx_number', simple_items.c.number))
+        simple_items.indexes.add(Index('idx_text_number', simple_items.c.text, simple_items.c.number))
 
-def test_indexes_table():
-    testmeta = MetaData()
-    simple_items = Table(
-        'simple_items', testmeta,
-        Column('id', Integer, primary_key=True),
-        Column('number', Integer),
-        Column('text', String)
-    )
-    simple_items.indexes.add(Index('idx_number', simple_items.c.number))
-    simple_items.indexes.add(Index('idx_text_number', simple_items.c.text, simple_items.c.number))
+        eq_(self.generate_code(testmeta), """\
+# coding: utf-8
+from sqlalchemy import Column, Index, Integer, MetaData, String, Table
 
-    table_def = remove_unicode_prefixes(generate_table(simple_items))
-    eq_(table_def, """\
+
+metadata = MetaData()
+
+
 t_simple_items = Table(
-    'simple_items', Base.metadata,
-    Column('id', Integer, primary_key=True),
+    'simple_items', metadata,
+    Column('id', Integer),
     Column('number', Integer, index=True),
     Column('text', String),
     Index('idx_text_number', 'text', 'number')
 )
 """)
 
+    def test_indexes_class(self):
+        testmeta = MetaData()
+        simple_items = Table(
+            'simple_items', testmeta,
+            Column('id', Integer, primary_key=True),
+            Column('number', Integer),
+            Column('text', String)
+        )
+        simple_items.indexes.add(Index('idx_number', simple_items.c.number))
+        simple_items.indexes.add(Index('idx_text_number', simple_items.c.text, simple_items.c.number))
 
-def test_indexes_class():
-    testmeta = MetaData()
-    simple_items = Table(
-        'simple_items', testmeta,
-        Column('id', Integer, primary_key=True),
-        Column('number', Integer),
-        Column('text', String)
-    )
-    simple_items.indexes.add(Index('idx_number', simple_items.c.number))
-    simple_items.indexes.add(Index('idx_text_number', simple_items.c.text, simple_items.c.number))
+        eq_(self.generate_code(testmeta), """\
+# coding: utf-8
+from sqlalchemy import Column, Index, Integer, String
+from sqlalchemy.ext.declarative import declarative_base
 
-    table_def = remove_unicode_prefixes(generate_class(simple_items))
-    eq_(table_def, """\
+
+Base = declarative_base()
+metadata = Base.metadata
+
+
 class SimpleItem(Base):
     __tablename__ = 'simple_items'
     __table_args__ = (
     text = Column(String)
 """)
 
+    def test_onetomany(self):
+        testmeta = MetaData()
+        Table(
+            'simple_items', testmeta,
+            Column('id', Integer, primary_key=True),
+            Column('container_id', Integer),
+            ForeignKeyConstraint(['container_id'], ['simple_containers.id']),
+        )
+        Table(
+            'simple_containers', testmeta,
+            Column('id', Integer, primary_key=True)
+        )
 
-def test_onetomany():
-    testmeta = MetaData()
-    Table(
-        'simple_items', testmeta,
-        Column('id', Integer, primary_key=True),
-        Column('container_id', Integer),
-        ForeignKeyConstraint(['container_id'], ['simple_containers.id']),
-    )
-    Table(
-        'simple_containers', testmeta,
-        Column('id', Integer, primary_key=True)
-    )
+        eq_(self.generate_code(testmeta), """\
+# coding: utf-8
+from sqlalchemy import Column, ForeignKey, Integer
+from sqlalchemy.orm import relationship
+from sqlalchemy.ext.declarative import declarative_base
 
-    table_defs = [remove_unicode_prefixes(table_def) for table_def in generate_declarative_models(testmeta)]
-    table_defs.sort()
-    eq_(len(table_defs), 2)
-    eq_(table_defs[0], """\
+
+Base = declarative_base()
+metadata = Base.metadata
+
+
 class SimpleContainer(Base):
     __tablename__ = 'simple_containers'
 
     id = Column(Integer, primary_key=True)
-""")
-    eq_(table_defs[1], """\
+
+
 class SimpleItem(Base):
     __tablename__ = 'simple_items'
 
     container = relationship('SimpleContainer')
 """)
 
+    def test_onetomany_selfref(self):
+        testmeta = MetaData()
+        Table(
+            'simple_items', testmeta,
+            Column('id', Integer, primary_key=True),
+            Column('parent_item_id', Integer),
+            ForeignKeyConstraint(['parent_item_id'], ['simple_items.id'])
+        )
 
-def test_onetomany_selfref():
-    testmeta = MetaData()
-    Table(
-        'simple_items', testmeta,
-        Column('id', Integer, primary_key=True),
-        Column('parent_item_id', Integer),
-        ForeignKeyConstraint(['parent_item_id'], ['simple_items.id'])
-    )
+        eq_(self.generate_code(testmeta), """\
+# coding: utf-8
+from sqlalchemy import Column, ForeignKey, Integer
+from sqlalchemy.orm import relationship
+from sqlalchemy.ext.declarative import declarative_base
 
-    table_defs = [remove_unicode_prefixes(table_def) for table_def in generate_declarative_models(testmeta)]
-    eq_(len(table_defs), 1)
-    eq_(table_defs[0], """\
+
+Base = declarative_base()
+metadata = Base.metadata
+
+
 class SimpleItem(Base):
     __tablename__ = 'simple_items'
 
     parent_item = relationship('SimpleItem', remote_side=[id])
 """)
 
+    def test_onetomany_selfref_multi(self):
+        testmeta = MetaData()
+        Table(
+            'simple_items', testmeta,
+            Column('id', Integer, primary_key=True),
+            Column('parent_item_id', Integer),
+            Column('top_item_id', Integer),
+            ForeignKeyConstraint(['parent_item_id'], ['simple_items.id']),
+            ForeignKeyConstraint(['top_item_id'], ['simple_items.id'])
+        )
 
-def test_onetomany_selfref_multi():
-    testmeta = MetaData()
-    Table(
-        'simple_items', testmeta,
-        Column('id', Integer, primary_key=True),
-        Column('parent_item_id', Integer),
-        Column('top_item_id', Integer),
-        ForeignKeyConstraint(['parent_item_id'], ['simple_items.id']),
-        ForeignKeyConstraint(['top_item_id'], ['simple_items.id'])
-    )
+        eq_(self.generate_code(testmeta), """\
+# coding: utf-8
+from sqlalchemy import Column, ForeignKey, Integer
+from sqlalchemy.orm import relationship
+from sqlalchemy.ext.declarative import declarative_base
 
-    table_defs = [remove_unicode_prefixes(table_def) for table_def in generate_declarative_models(testmeta)]
-    eq_(len(table_defs), 1)
-    eq_(table_defs[0], """\
+
+Base = declarative_base()
+metadata = Base.metadata
+
+
 class SimpleItem(Base):
     __tablename__ = 'simple_items'
 
     top_item = relationship('SimpleItem', remote_side=[id], primaryjoin='SimpleItem.top_item_id == SimpleItem.id')
 """)
 
+    def test_onetomany_composite(self):
+        testmeta = MetaData()
+        Table(
+            'simple_items', testmeta,
+            Column('id', Integer, primary_key=True),
+            Column('container_id1', Integer),
+            Column('container_id2', Integer),
+            ForeignKeyConstraint(['container_id1', 'container_id2'], ['simple_containers.id1', 'simple_containers.id2'])
+        )
+        Table(
+            'simple_containers', testmeta,
+            Column('id1', Integer, primary_key=True),
+            Column('id2', Integer, primary_key=True)
+        )
 
-def test_onetomany_composite():
-    testmeta = MetaData()
-    Table(
-        'simple_items', testmeta,
-        Column('id', Integer, primary_key=True),
-        Column('container_id1', Integer),
-        Column('container_id2', Integer),
-        ForeignKeyConstraint(['container_id1', 'container_id2'], ['simple_containers.id1', 'simple_containers.id2'])
-    )
-    Table(
-        'simple_containers', testmeta,
-        Column('id1', Integer, primary_key=True),
-        Column('id2', Integer, primary_key=True)
-    )
+        eq_(self.generate_code(testmeta), """\
+# coding: utf-8
+from sqlalchemy import Column, ForeignKeyConstraint, Integer
+from sqlalchemy.orm import relationship
+from sqlalchemy.ext.declarative import declarative_base
 
-    table_defs = [remove_unicode_prefixes(table_def) for table_def in generate_declarative_models(testmeta)]
-    table_defs.sort()
-    eq_(len(table_defs), 2)
-    eq_(table_defs[0], """\
+
+Base = declarative_base()
+metadata = Base.metadata
+
+
 class SimpleContainer(Base):
     __tablename__ = 'simple_containers'
 
     id1 = Column(Integer, primary_key=True, nullable=False)
     id2 = Column(Integer, primary_key=True, nullable=False)
-""")
-    eq_(table_defs[1], """\
+
+
 class SimpleItem(Base):
     __tablename__ = 'simple_items'
     __table_args__ = (
     simple_container = relationship('SimpleContainer')
 """)
 
+    def test_onetomany_multiref(self):
+        testmeta = MetaData()
+        Table(
+            'simple_items', testmeta,
+            Column('id', Integer, primary_key=True),
+            Column('parent_container_id', Integer),
+            Column('top_container_id', Integer),
+            ForeignKeyConstraint(['parent_container_id'], ['simple_containers.id']),
+            ForeignKeyConstraint(['top_container_id'], ['simple_containers.id'])
+        )
+        Table(
+            'simple_containers', testmeta,
+            Column('id', Integer, primary_key=True)
+        )
 
-def test_onetomany_multiref():
-    testmeta = MetaData()
-    Table(
-        'simple_items', testmeta,
-        Column('id', Integer, primary_key=True),
-        Column('parent_container_id', Integer),
-        Column('top_container_id', Integer),
-        ForeignKeyConstraint(['parent_container_id'], ['simple_containers.id']),
-        ForeignKeyConstraint(['top_container_id'], ['simple_containers.id'])
-    )
-    Table(
-        'simple_containers', testmeta,
-        Column('id', Integer, primary_key=True)
-    )
+        eq_(self.generate_code(testmeta), """\
+# coding: utf-8
+from sqlalchemy import Column, ForeignKey, Integer
+from sqlalchemy.orm import relationship
+from sqlalchemy.ext.declarative import declarative_base
 
-    table_defs = [remove_unicode_prefixes(table_def) for table_def in generate_declarative_models(testmeta)]
-    table_defs.sort()
-    eq_(len(table_defs), 2)
-    eq_(table_defs[0], """\
+
+Base = declarative_base()
+metadata = Base.metadata
+
+
 class SimpleContainer(Base):
     __tablename__ = 'simple_containers'
 
     id = Column(Integer, primary_key=True)
-""")
-    eq_(table_defs[1], """\
+
+
 class SimpleItem(Base):
     __tablename__ = 'simple_items'
 
     top_container = relationship('SimpleContainer', primaryjoin='SimpleItem.top_container_id == SimpleContainer.id')
 """)
 
+    def test_onetoone(self):
+        testmeta = MetaData()
+        Table(
+            'simple_items', testmeta,
+            Column('id', Integer, primary_key=True),
+            Column('other_item_id', Integer),
+            ForeignKeyConstraint(['other_item_id'], ['other_items.id']),
+            UniqueConstraint('other_item_id')
+        )
+        Table(
+            'other_items', testmeta,
+            Column('id', Integer, primary_key=True)
+        )
 
-def test_onetoone():
-    testmeta = MetaData()
-    Table(
-        'simple_items', testmeta,
-        Column('id', Integer, primary_key=True),
-        Column('other_item_id', Integer),
-        ForeignKeyConstraint(['other_item_id'], ['other_items.id']),
-        UniqueConstraint('other_item_id')
-    )
-    Table(
-        'other_items', testmeta,
-        Column('id', Integer, primary_key=True)
-    )
+        eq_(self.generate_code(testmeta), """\
+# coding: utf-8
+from sqlalchemy import Column, ForeignKey, Integer
+from sqlalchemy.orm import relationship
+from sqlalchemy.ext.declarative import declarative_base
 
-    table_defs = [remove_unicode_prefixes(table_def) for table_def in generate_declarative_models(testmeta)]
-    table_defs.sort()
-    eq_(len(table_defs), 2)
-    eq_(table_defs[0], """\
+
+Base = declarative_base()
+metadata = Base.metadata
+
+
 class OtherItem(Base):
     __tablename__ = 'other_items'
 
     id = Column(Integer, primary_key=True)
-""")
-    eq_(table_defs[1], """\
+
+
 class SimpleItem(Base):
     __tablename__ = 'simple_items'
 
     other_item = relationship('OtherItem', uselist=False)
 """)
 
+    def test_manytomany(self):
+        testmeta = MetaData()
+        Table(
+            'simple_items', testmeta,
+            Column('id', Integer, primary_key=True)
+        )
+        Table(
+            'simple_containers', testmeta,
+            Column('id', Integer, primary_key=True)
+        )
+        Table(
+            'container_items', testmeta,
+            Column('item_id', Integer),
+            Column('container_id', Integer),
+            ForeignKeyConstraint(['item_id'], ['simple_items.id']),
+            ForeignKeyConstraint(['container_id'], ['simple_containers.id'])
+        )
 
-def test_manytomany():
-    testmeta = MetaData()
-    Table(
-        'simple_items', testmeta,
-        Column('id', Integer, primary_key=True)
-    )
-    Table(
-        'simple_containers', testmeta,
-        Column('id', Integer, primary_key=True)
-    )
-    Table(
-        'container_items', testmeta,
-        Column('item_id', Integer),
-        Column('container_id', Integer),
-        ForeignKeyConstraint(['item_id'], ['simple_items.id']),
-        ForeignKeyConstraint(['container_id'], ['simple_containers.id'])
-    )
+        eq_(self.generate_code(testmeta), """\
+# coding: utf-8
+from sqlalchemy import Column, ForeignKey, Integer, Table
+from sqlalchemy.orm import relationship
+from sqlalchemy.ext.declarative import declarative_base
 
-    table_defs = [remove_unicode_prefixes(table_def) for table_def in generate_declarative_models(testmeta)]
-    table_defs.sort()
-    eq_(len(table_defs), 3)
-    eq_(table_defs[0], """\
+
+Base = declarative_base()
+metadata = Base.metadata
+
+
+t_container_items = Table(
+    'container_items', metadata,
+    Column('item_id', ForeignKey('simple_items.id')),
+    Column('container_id', ForeignKey('simple_containers.id'))
+)
+
+
 class SimpleContainer(Base):
     __tablename__ = 'simple_containers'
 
     id = Column(Integer, primary_key=True)
 
     items = relationship('SimpleItem', secondary='container_items')
-""")
-    eq_(table_defs[1], """\
+
+
 class SimpleItem(Base):
     __tablename__ = 'simple_items'
 
     id = Column(Integer, primary_key=True)
 """)
-    eq_(table_defs[2], """\
-t_container_items = Table(
-    'container_items', Base.metadata,
-    Column('item_id', ForeignKey('simple_items.id')),
-    Column('container_id', ForeignKey('simple_containers.id'))
+
+    def test_manytomany_selfref(self):
+        testmeta = MetaData()
+        Table(
+            'simple_items', testmeta,
+            Column('id', Integer, primary_key=True)
+        )
+        Table(
+            'child_items', testmeta,
+            Column('parent_id', Integer),
+            Column('child_id', Integer),
+            ForeignKeyConstraint(['parent_id'], ['simple_items.id']),
+            ForeignKeyConstraint(['child_id'], ['simple_items.id'])
+        )
+
+        eq_(self.generate_code(testmeta), """\
+# coding: utf-8
+from sqlalchemy import Column, ForeignKey, Integer, Table
+from sqlalchemy.orm import relationship
+from sqlalchemy.ext.declarative import declarative_base
+
+
+Base = declarative_base()
+metadata = Base.metadata
+
+
+t_child_items = Table(
+    'child_items', metadata,
+    Column('parent_id', ForeignKey('simple_items.id')),
+    Column('child_id', ForeignKey('simple_items.id'))
 )
-""")
 
 
-def test_manytomany_selfref():
-    testmeta = MetaData()
-    Table(
-        'simple_items', testmeta,
-        Column('id', Integer, primary_key=True)
-    )
-    Table(
-        'child_items', testmeta,
-        Column('parent_id', Integer),
-        Column('child_id', Integer),
-        ForeignKeyConstraint(['parent_id'], ['simple_items.id']),
-        ForeignKeyConstraint(['child_id'], ['simple_items.id'])
-    )
-
-    table_defs = [remove_unicode_prefixes(table_def) for table_def in generate_declarative_models(testmeta)]
-    table_defs.sort()
-    eq_(len(table_defs), 2)
-    eq_(table_defs[0], """\
 class SimpleItem(Base):
     __tablename__ = 'simple_items'
 
         secondaryjoin='SimpleItem.id == child_items.c.parent_id'
     )
 """)
-    eq_(table_defs[1], """\
-t_child_items = Table(
-    'child_items', Base.metadata,
-    Column('parent_id', ForeignKey('simple_items.id')),
-    Column('child_id', ForeignKey('simple_items.id'))
+
+    def test_manytomany_composite(self):
+        testmeta = MetaData()
+        Table(
+            'simple_items', testmeta,
+            Column('id1', Integer, primary_key=True),
+            Column('id2', Integer, primary_key=True)
+        )
+        Table(
+            'simple_containers', testmeta,
+            Column('id1', Integer, primary_key=True),
+            Column('id2', Integer, primary_key=True)
+        )
+        Table(
+            'container_items', testmeta,
+            Column('item_id1', Integer),
+            Column('item_id2', Integer),
+            Column('container_id1', Integer),
+            Column('container_id2', Integer),
+            ForeignKeyConstraint(['item_id1', 'item_id2'], ['simple_items.id1', 'simple_items.id2']),
+            ForeignKeyConstraint(['container_id1', 'container_id2'], ['simple_containers.id1', 'simple_containers.id2'])
+        )
+
+        eq_(self.generate_code(testmeta), """\
+# coding: utf-8
+from sqlalchemy import Column, ForeignKeyConstraint, Integer, Table
+from sqlalchemy.orm import relationship
+from sqlalchemy.ext.declarative import declarative_base
+
+
+Base = declarative_base()
+metadata = Base.metadata
+
+
+t_container_items = Table(
+    'container_items', metadata,
+    Column('item_id1', Integer),
+    Column('item_id2', Integer),
+    Column('container_id1', Integer),
+    Column('container_id2', Integer),
+    ForeignKeyConstraint(['container_id1', 'container_id2'], ['simple_containers.id1', 'simple_containers.id2']),
+    ForeignKeyConstraint(['item_id1', 'item_id2'], ['simple_items.id1', 'simple_items.id2'])
 )
-""")
 
 
-def test_manytomany_composite():
-    testmeta = MetaData()
-    Table(
-        'simple_items', testmeta,
-        Column('id1', Integer, primary_key=True),
-        Column('id2', Integer, primary_key=True)
-    )
-    Table(
-        'simple_containers', testmeta,
-        Column('id1', Integer, primary_key=True),
-        Column('id2', Integer, primary_key=True)
-    )
-    Table(
-        'container_items', testmeta,
-        Column('item_id1', Integer),
-        Column('item_id2', Integer),
-        Column('container_id1', Integer),
-        Column('container_id2', Integer),
-        ForeignKeyConstraint(['item_id1', 'item_id2'], ['simple_items.id1', 'simple_items.id2']),
-        ForeignKeyConstraint(['container_id1', 'container_id2'], ['simple_containers.id1', 'simple_containers.id2'])
-    )
-
-    table_defs = [remove_unicode_prefixes(table_def) for table_def in generate_declarative_models(testmeta)]
-    table_defs.sort()
-    eq_(len(table_defs), 3)
-    eq_(table_defs[0], """\
 class SimpleContainer(Base):
     __tablename__ = 'simple_containers'
 
     id2 = Column(Integer, primary_key=True, nullable=False)
 
     simple_items = relationship('SimpleItem', secondary='container_items')
-""")
-    eq_(table_defs[1], """\
+
+
 class SimpleItem(Base):
     __tablename__ = 'simple_items'
 
     id1 = Column(Integer, primary_key=True, nullable=False)
     id2 = Column(Integer, primary_key=True, nullable=False)
 """)
-    eq_(table_defs[2], """\
-t_container_items = Table(
-    'container_items', Base.metadata,
-    Column('item_id1', Integer),
-    Column('item_id2', Integer),
-    Column('container_id1', Integer),
-    Column('container_id2', Integer),
-    ForeignKeyConstraint(['container_id1', 'container_id2'], ['simple_containers.id1', 'simple_containers.id2']),
-    ForeignKeyConstraint(['item_id1', 'item_id2'], ['simple_items.id1', 'simple_items.id2'])
-)
-""")
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.