polymorphic_class example

Issue #1951 resolved
Mike Bayer (repo owner) created an issue

This can go on the wiki once #1950 is complete:

from sqlalchemy import util
from sqlalchemy.types import TypeDecorator
from sqlalchemy.orm import class_mapper

try:
    # Use SQLAlchemy's own type_coerce() when this version provides it.
    from sqlalchemy.sql.expression import type_coerce
except ImportError:
    # Fallback when type_coerce() is not importable: emulate it by wrapping
    # the expression in an unnamed _Label carrying the requested type.
    # NOTE(review): _Label is private SQLAlchemy API; the positional
    # (name, element) + type_= signature is assumed from this call — confirm
    # against the SQLAlchemy versions this must support.
    from sqlalchemy.sql.expression import _Label
    def type_coerce(expr, type_):
        """Return *expr* wrapped in an unnamed label typed as *type_*."""
        return _Label(None, expr, type_=type_)


class polymorphic_class(object):
    """Descriptor exposing an object's polymorphic class as an attribute.

    On an instance it returns the instance's class.  On a mapped class it
    returns the mapper's ``polymorphic_on`` column coerced to a type that
    translates between polymorphic identities and mapped classes, so the
    attribute can be selected in queries (yielding classes) and compared
    against classes (binding their identities).
    """

    # NOTE(review): util.memoized_instancemethod caches the result on this
    # descriptor instance; presumably later calls reuse the first result
    # regardless of ``owner``.  That is harmless when every subclass shares
    # the same polymorphic_on column and polymorphic_map (as in single-table
    # inheritance) — confirm before using with other inheritance layouts.
    @util.memoized_instancemethod
    def _from_owner(self, owner):
        """Build the class-level SQL expression for mapped class *owner*.

        Returns ``mapper.polymorphic_on`` coerced to a ``PolyType`` whose
        bind processing maps a class to its polymorphic identity and whose
        result processing maps an identity back to its class.
        """
        mapper = class_mapper(owner)

        # One dict serves both directions: identity -> class (result rows)
        # and class -> identity (bind parameters).
        poly_map = {}
        for ident, sub_mapper in mapper.polymorphic_map.items():
            poly_map[ident] = sub_mapper.class_
            poly_map[sub_mapper.class_] = ident

        class PolyType(TypeDecorator):
            # Reuse the discriminator column's type class as the underlying
            # implementation type.
            impl = type(mapper.polymorphic_on.type)

            def process_bind_param(self, value, dialect):
                # class -> polymorphic identity; None passes through as
                # None, unknown classes raise KeyError.
                if value is not None:
                    return poly_map[value]

            def process_result_value(self, value, dialect):
                # polymorphic identity -> class; None passes through as
                # None, unknown identities raise KeyError.
                if value is not None:
                    return poly_map[value]

        return type_coerce(mapper.polymorphic_on, type_=PolyType())

    def __get__(self, instance, owner):
        """Return the instance's class, or a SQL expression on class access."""
        if instance is not None:
            return instance.__class__

        return self._from_owner(owner)

if __name__ == '__main__':
    # Demo: single-table inheritance where Person.cls_type can be selected
    # (yielding classes) and filtered against a class.
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy import create_engine, Column, String, Integer
    from sqlalchemy.orm import Session

    Base = declarative_base()

    class Person(Base):
        __tablename__ = 'person'
        id = Column(Integer, primary_key=True)
        name = Column(String(50))
        type_ = Column(Integer)

        # Instance access yields the class; class access yields an expression.
        cls_type = polymorphic_class()

        __mapper_args__ = {'polymorphic_on': type_}

    class Engineer(Person):
        __mapper_args__ = {'polymorphic_identity': 1}

    class Manager(Person):
        __mapper_args__ = {'polymorphic_identity': 2}

    engine = create_engine('sqlite://', echo=True)
    Base.metadata.create_all(engine)

    sess = Session(engine)

    e1 = Engineer(name='dilbert')
    e2 = Engineer(name='wally')
    m1 = Manager(name='dogbert')
    assert e2.cls_type is Engineer
    assert m1.cls_type is Manager

    sess.add_all([e1, e2, m1])

    # this lambda can go away when #1950 is complete
    result = sess.query(Person.name, Person.cls_type).\
        filter(Person.cls_type == (lambda: Engineer)).\
        order_by(Person.name).\
        all()

    # Compare explicitly: ``assert result, expected`` would only check that
    # ``result`` is non-empty, using ``expected`` as the failure message.
    expected = [('dilbert', Engineer), ('wally', Engineer)]
    assert result == expected, result

Comments (2)

  1. Log in to comment