Clone wiki

sqlalchemy / UsageRecipes / SessionIndexing


This recipe presents a generalized way to "index" objects in memory as they are placed into Sessions, so that later they can be retrieved based on particular criteria. The use case for this could be to assist in writing before_flush() event handlers, where particular subsets of objects in a Session need to be inspected, and a SQL round trip is specifically not wanted, typically due to performance concerns.

The technique is deliberately simple and does not account for objects that are mutated while in the Session in a way that would change how they are indexed. To handle that, attribute change events (e.g. the attribute "set" event) would also need to be intercepted, so that the affected object is re-indexed under the relevant index.

In contrast to this recipe, it is of course vastly simpler to just use the Session normally, emitting a query against the database whose results are then correlated against what's already in the Session's identity map; the use case here is specifically one of avoiding those round trips.

import weakref
import collections
from sqlalchemy import event
from sqlalchemy.orm import Session
from sqlalchemy.orm import mapper

class Index(object):
    """An in-memory 'index' of objects in sessions.

    Listens for objects being attached to sessions (and for objects
    loaded from the database) and indexes them according to a series
    of user-defined "indexing" functions, registered with
    :meth:`Index.indexed`.  Lookups are done by attribute access on the
    index object, e.g. ``indexes.user_byname(session, "b")``.

    Each object is indexed once, at attach/load time; later changes to
    its attributes are not reflected in the index.
    """

    def __init__(self):
        # dictionary of (mapped class ->
        #                  dictionary of (index name -> indexing record))
        self._index_fns = weakref.WeakKeyDictionary()

        # dictionary of (session object ->
        #                  dictionary of
        #                      ((indexname, value) -> weak set of instances)
        #                )
        self._by_session = weakref.WeakKeyDictionary()

        # The listeners below are registered on ``mapper`` and on the
        # ``Session`` class, so they fire for every mapper / session.
        @event.listens_for(mapper, "load")
        def object_loaded(instance, ctx):
            self._index_object(ctx.session, instance)

        @event.listens_for(Session, "after_attach")
        def index_object(session, instance):
            self._index_object(session, instance)

    def _index_object(self, session, instance):
        """Add ``instance`` to every applicable index kept for ``session``."""
        # Get (or create) the per-session dictionary of sets.  A WeakSet
        # means the index never keeps an object alive by itself, e.g.
        # after it has been expunged from the session.
        by_session = self._by_session.get(session)
        if by_session is None:
            by_session = collections.defaultdict(weakref.WeakSet)
            self._by_session[session] = by_session

        # find all the indexes for this object's class,
        # and superclasses too.
        typ = type(instance)
        for cls in typ.__mro__:
            byclass = self._index_fns.get(cls)
            if not byclass:
                continue
            # all the "index" functions registered for this class
            for name, rec in byclass.items():
                # Compare the *instance's* type with the registered class.
                # ``cls`` is always rec['cls'] here, so testing it would
                # silently ignore include_subclasses=False.
                if rec['include_subclasses'] or typ is rec['cls']:
                    # call the indexing function, build a key
                    key = name, rec['fn'](instance)
                    # add to the set for that key
                    by_session[key].add(instance)

    def indexed(self, cls, name, include_subclasses=True):
        """Log a function as indexing a certain class.

        Returns a decorator.  The decorated function is called with each
        instance of ``cls`` (and of its subclasses, unless
        ``include_subclasses`` is False) as it is attached to or loaded
        into a session.  It must return a hashable value, which becomes
        part of the lookup key ``(name, value)``.

        Only objects attached/loaded after registration are indexed.
        Registering ``name`` again for the same class replaces the
        earlier function.  The decorated function is returned unchanged.
        """
        byclass = self._index_fns.setdefault(cls, {})

        def decorate(fn):
            byclass[name] = {
                "fn": fn,
                "cls": cls,
                "include_subclasses": include_subclasses,
            }
            return fn
        return decorate

    def __getattr__(self, name):
        """Return an index-lookup function for the index called ``name``.

        The returned function takes ``(sess, value)`` and returns a new
        set of the objects indexed under ``(name, value)`` for ``sess``
        that are still contained in ``sess``.  Unknown sessions or values
        yield an empty set.
        """
        # Don't masquerade as special methods: e.g. copy.deepcopy probes
        # ``__deepcopy__`` on the instance and would get a lookup function.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)

        def go(sess, value):
            by_session = self._by_session.get(sess)
            if by_session is None:
                return set()
            # .get() so a miss does not insert an empty defaultdict entry
            objs = by_session.get((name, value))
            if not objs:
                return set()
            # drop anything no longer in the session (e.g. expunged);
            # checking membership per candidate avoids scanning the
            # whole session
            return {obj for obj in objs if obj in sess}
        return go

# Module-level shared index.  Creating it registers its event listeners.
indexes = Index()

if __name__ == '__main__':
    # demonstration

    from sqlalchemy import Column, String, Integer
    from sqlalchemy.orm import Session
    from sqlalchemy.ext.declarative import declarative_base

    Base = declarative_base()

    class User(Base):
        __tablename__ = 'user'

        id = Column(Integer, primary_key=True)
        name = Column(String)

    class Address(Base):
        __tablename__ = 'address'

        id = Column(Integer, primary_key=True)
        name = Column(String)

    # index functions return the value each object is filed under
    @indexes.indexed(User, "user_byname")
    def index_user_byname(obj):
        return obj.name

    @indexes.indexed(Address, "address_byname")
    def index_address_byname(obj):
        return obj.name

    a1, a2, a3 = User(name='a'), User(name='a'), User(name='a')
    b1, b2, b3 = User(name='b'), User(name='b'), User(name='b')
    c1, c2, c3 = User(name='c'), User(name='c'), User(name='c')
    d1, d2, d3 = User(name='d'), User(name='d'), User(name='d')
    e1, e2, e3 = User(name='e'), User(name='e'), User(name='e')

    ad_a, ad_b, ad_c = Address(name='a'), Address(name='b'), Address(name='c')

    s1, s2, s3 = Session(), Session(), Session()

    # add_all() fires after_attach for each object, indexing it per session
    s1.add_all([a1, b1, b2, d2, e3, ad_c])
    s2.add_all([a2, c2, e1, e2, ad_a])
    s3.add_all([b3, c1, d1, d3, ad_b])

    assert indexes.user_byname(s1, "b") == set([b1, b2])
    assert indexes.user_byname(s2, "e") == set([e1, e2])
    assert indexes.user_byname(s2, "c") == set([c2])
    assert indexes.user_byname(s3, "b") == set([b3])
    assert indexes.address_byname(s3, "b") == set([ad_b])
    assert indexes.address_byname(s3, "c") == set()
    assert indexes.address_byname(s1, "c") == set([ad_c])

    # objects removed from a session no longer show up in its lookups
    s2.expunge(e2)
    assert indexes.user_byname(s2, "e") == set([e1])
    s2.expunge(e1)
    assert indexes.user_byname(s2, "e") == set([])