Wiki

Clone wiki

sqlalchemy / UsageRecipes / AutoRelationships

AutoRelationships

Illustrates how to generate relationship() constructs on mapped classes by inspecting the foreign key constraints on each table.

This is a rudimentary example; the general idea can (and should) be extended into a more complete system suited to your actual usage scenarios.

from sqlalchemy.orm import relationship, backref
from sqlalchemy import ForeignKeyConstraint

def camelcase(name):
    """Return ``name`` with its first character upper-cased.

    Only the first character is changed; the rest of the string, including
    underscores, is kept as-is (``"user_account"`` -> ``"User_account"``).
    An empty string is returned unchanged rather than raising IndexError.
    """
    # slicing ([:1]) instead of indexing ([0]) tolerates the empty string
    return name[:1].upper() + name[1:]

def _unique_attr_name(cls, name, taken):
    """Return the first of ``name``, ``name_2``, ``name_3``, ... not yet used on ``cls``.

    A candidate counts as used if it was handed out earlier (recorded in the
    ``taken`` dict of per-class name sets) or if ``cls`` already has an
    attribute of that name (e.g. a mapped column). The chosen name is
    recorded in ``taken`` before it is returned.
    """
    used = taken.setdefault(cls, set())
    candidate, n = name, 1
    while candidate in used or hasattr(cls, candidate):
        n += 1
        candidate = "%s_%d" % (name, n)
    used.add(candidate)
    return candidate


def map_metadata(Base, metadata):
    """Generate a mapped class per table in ``metadata`` and link them via FKs.

    One subclass of ``Base`` is created for every table, named
    ``camelcase(table.name)`` and mapped with ``__table__``. Then, for every
    ForeignKeyConstraint on a table:

    * a ``relationship()`` to the referred class is set on the referencing
      class, named after the referred class lowercased (e.g. ``A.d``);
    * a backref named ``<referencing class lowercased>_collection`` is placed
      on the referred class (e.g. ``D.a_collection``), with
      ``passive_deletes="all"``.

    If a name is already taken on the target class (by a column, by an
    earlier generated relationship/backref, or by an inherited attribute), a
    numeric suffix (``_2``, ``_3``, ...) is appended, so several foreign keys
    to the same table do not overwrite each other.

    :param Base: declarative base class the generated classes subclass.
    :param metadata: MetaData whose tables are mapped.
    :return: dict mapping each generated class name (str) to its class.
    """
    result = {}
    table_to_class = {}
    for table in metadata.tables.values():
        # str() rather than .encode('ascii'): type() needs a native str both
        # on Python 2 (reflected names are unicode there) and on Python 3
        # (where .encode() would give bytes and type() raises TypeError).
        clsname = str(camelcase(table.name))
        # Only __table__ goes in the class dict: any other plain class
        # attribute would make the mapper skip a same-named table column.
        cls = type(clsname, (Base,), {"__table__": table})
        result[clsname] = cls
        table_to_class[table] = cls

    # Names handed out so far, per class. Backrefs are only created when the
    # mappers get configured, so hasattr() alone would not see them yet.
    taken = {}
    for table in metadata.tables.values():
        local_cls = table_to_class[table]
        for constraint in table.constraints:
            if not isinstance(constraint, ForeignKeyConstraint):
                continue
            fks = list(constraint.elements)
            if not fks:
                continue
            referred_cls = table_to_class[fks[0].column.table]

            extra = {}
            if referred_cls is local_cls:
                # Self-referential (adjacency list): without remote_side the
                # relationship would be one-to-many, inverting the
                # scalar-attribute / "_collection"-backref convention used
                # for every other table.
                extra["remote_side"] = [fk.column for fk in fks]

            attr_name = _unique_attr_name(
                local_cls, referred_cls.__name__.lower(), taken)
            backref_name = _unique_attr_name(
                referred_cls, local_cls.__name__.lower() + "_collection", taken)

            setattr(
                local_cls,
                attr_name,
                relationship(
                    referred_cls,
                    # explicit foreign_keys disambiguates multiple FK paths
                    # between the same pair of tables
                    foreign_keys=[fk.parent for fk in fks],
                    backref=backref(
                        backref_name,
                        passive_deletes="all",  # optional, replace this
                                                # with desired cascading behavior
                    ),
                    **extra
                ),
            )
    return result




if __name__ == '__main__':
    # Demo: create tables a -> d, b -> a, c -> d in a scratch database,
    # reflect them back, auto-generate mapped classes and relationships with
    # map_metadata(), and exercise them.
    #
    # WARNING: drops and recreates tables a, b, c and d in the target
    # database. Pass a database URL as the first command-line argument to
    # override the default below.
    import sys

    from sqlalchemy import (
        create_engine, MetaData, Table, Column, Integer, ForeignKey)
    from sqlalchemy.orm import Session
    try:
        # SQLAlchemy 1.4+: declarative_base lives in sqlalchemy.orm (the
        # sqlalchemy.ext.declarative location is deprecated there).
        from sqlalchemy.orm import declarative_base
    except ImportError:
        # older SQLAlchemy releases
        from sqlalchemy.ext.declarative import declarative_base

    url = (sys.argv[1] if len(sys.argv) > 1
           else "postgresql://scott:tiger@localhost/test")

    metadata = MetaData()

    a = Table('a', metadata,
            Column('id', Integer, primary_key=True),
            Column('did', ForeignKey('d.id', ondelete="CASCADE"))
        )
    b = Table('b', metadata,
            Column('id', Integer, primary_key=True),
            Column('aid', ForeignKey('a.id', ondelete="CASCADE"))
        )
    c = Table('c', metadata,
            Column('id', Integer, primary_key=True),
            Column('did', ForeignKey('d.id', ondelete="CASCADE"))
        )
    d = Table('d', metadata,
            Column('id', Integer, primary_key=True)
        )
    e = create_engine(url, echo=True)
    metadata.drop_all(e)
    metadata.create_all(e)

    # Reflect only the demo's tables: other tables already in the database
    # (possibly without a primary key, hence unmappable) would otherwise be
    # passed to map_metadata() as well.
    reflected_metadata = MetaData()
    reflected_metadata.reflect(e, only=list(metadata.tables))

    Base = declarative_base(metadata=reflected_metadata)

    classes = map_metadata(Base, reflected_metadata)

    A, B, C, D = classes["A"], classes["B"], classes["C"], classes["D"]

    sess = Session(e)
    try:
        sess.add_all([
            A(d=D(), b_collection=[B(), B()]),
            D(c_collection=[C(), C()])
        ])
        sess.commit()

        a1 = sess.query(A).first()
        assert len(a1.b_collection) == 2
        d1 = sess.query(D).filter(D.c_collection.any()).first()
        assert len(d1.c_collection) == 2

        sess.delete(d1)
        sess.delete(a1)
        sess.commit()
    finally:
        # release the connection and the pool even if a step above fails
        sess.close()
        e.dispose()

Updated