Commits

Mike Bayer committed ee753d7

latest reorgnanization of the objectstore, the Session is a simpler object that just maintains begin/commit state

Comments (0)

Files changed (5)

lib/sqlalchemy/mapping/mapper.py

             oldinit = self.class_.__init__
             def init(self, *args, **kwargs):
                 nohist = kwargs.pop('_mapper_nohistory', False)
-                session = kwargs.pop('_sa_session', objectstore.session())
+                session = kwargs.pop('_sa_session', objectstore.get_session())
                 if oldinit is not None:
                     try:
                         oldinit(self, *args, **kwargs)
                 
         # store new stuff in the identity map
         for value in imap.values():
-            objectstore.session().register_clean(value)
+            objectstore.get_session().register_clean(value)
 
         if len(mappers):
             return [result] + otherresults
         
     def _get(self, key, ident=None):
         try:
-            return objectstore.session()._get(key)
+            return objectstore.get_session()._get(key)
         except KeyError:
             if ident is None:
                 ident = key[2]
         # including modifying any of its related items lists, as its already
         # been exposed to being modified by the application.
         identitykey = self._identity_key(row)
-        if objectstore.session().has_key(identitykey):
-            instance = objectstore.session()._get(identitykey)
+        if objectstore.get_session().has_key(identitykey):
+            instance = objectstore.get_session()._get(identitykey)
 
             isnew = False
             if populate_existing:

lib/sqlalchemy/mapping/objectstore.py

     The registry is capable of maintaining object instances on a thread-local, 
     per-application, or custom user-defined basis."""
     
-    def __init__(self, scope="application", getter=None, hash_key=None, keyfunc=None):
+    def __init__(self, nest_transactions=False, hash_key=None):
         """Initialize the objectstore with a UnitOfWork registry.  If called
         with no arguments, creates a single UnitOfWork for all operations.
         
-        scope - "application" or "thread", the two default scopes
-        getter - a callable that takes this Session as an argument and returns a 
-        new UnitOfWork.
+        nest_transactions - indicates begin/commit statements can be executed in a
+        "nested", defaults to False which indicates "only commit on the outermost begin/commit"
         hash_key - the hash_key used to identify objects against this session, which 
         defaults to the id of the Session instance.
-        keyfunc - allows custom scopes by providing a callable to return the "key"
-        identifying the desired UnitOfWork.
         """
-        if keyfunc is None:
-            if scope=="thread":
-                keyfunc = thread.get_ident
-            elif scope=="application":
-                keyfunc = lambda: True
-        if getter is None:
-            def createfunc():
-                return UnitOfWork(self)
+        self.uow = UnitOfWork()
+        self.parent_uow = None
+        self.begin_count = 0
+        self.nest_transactions = nest_transactions
+        if hash_key is None:
+            self.hash_key = id(self)
         else:
-            createfunc = lambda: getter(self)
-        self.registry = util.ScopedRegistry(createfunc, keyfunc)
-        self._hash_key = hash_key
-
+            self.hash_key = hash_key
+        _sessions[self.hash_key] = self
+        
     def get_id_key(ident, class_, table):
         """returns an identity-map key for use in storing/retrieving an item from the identity
         map, given a tuple of the object's primary key values.
         return (class_, table.hash_key(), tuple([row[column] for column in primary_key]))
     get_row_key = staticmethod(get_row_key)
     
-    def _set_uow(self, uow):
-        self.registry.set(uow)
-    uow = property(lambda s:s.registry(), _set_uow, doc="Returns a scope-specific UnitOfWork object for this session.")
-    
-    hash_key = property(lambda s:s._hash_key or id(s))
+    def begin(self):
+        """begins a new UnitOfWork transaction.  the next commit will affect only
+        objects that are created, modified, or deleted following the begin statement."""
+        self.begin_count += 1
+        if self.parent_uow is not None:
+            return
+        self.parent_uow = self.uow            
+        self.uow = UnitOfWork(identity_map = self.uow.identity_map)
+        
+    def commit(self, *objects):
+        """commits the current UnitOfWork transaction.  if a transaction was begun 
+        via begin(), commits only those objects that were created, modified, or deleted
+        since that begin statement.  otherwise commits all objects that have been
+        changed.
+        if individual objects are submitted, then only those objects are committed, and the 
+        begin/commit cycle is not affected."""
+        # if an object list is given, commit just those but dont
+        # change begin/commit status
+        if len(objects):
+            self.uow.commit(*objects)
+            return
+        if self.parent_uow is not None:
+            self.begin_count -= 1
+            if self.begin_count > 0:
+                return
+        self.uow.commit()
+        if self.parent_uow is not None:
+            self.uow = self.parent_uow
+            self.parent_uow = None
 
-    def bind_to(self, obj):
+    def rollback(self):
+        """rolls back the current UnitOfWork transaction, in the case that begin()
+        has been called.  The changes logged since the begin() call are discarded."""
+        if self.parent_uow is None:
+            raise "UOW transaction is not begun"
+        self.uow = self.parent_uow
+        self.parent_uow = None
+        self.begin_count = 0
+        
+    def register_clean(self, obj):
+        self._bind_to(obj)
+        self.uow.register_clean(obj)
+        
+    def register_new(self, obj):
+        self._bind_to(obj)
+        self.uow.register_new(obj)
+
+    def _bind_to(self, obj):
         """given an object, binds it to this session.  changes on the object will affect
         the currently scoped UnitOfWork maintained by this session."""
         obj._sa_session_id = self.hash_key
 
     def __getattr__(self, key):
         """proxy other methods to our underlying UnitOfWork"""
-        return getattr(self.registry(), key)
+        return getattr(self.uow, key)
 
     def clear(self):
-        self.registry.clear()
+        self.uow = UnitOfWork()
 
-    def delete(*obj):
+    def delete(self, *obj):
         """registers the given objects as to be deleted upon the next commit"""
-        u = registry()
         for o in obj:
-            u.register_deleted(o)
+            self.uow.register_deleted(o)
         
     def import_instance(self, instance):
         """places the given instance in the current thread's unit of work context,
         key = getattr(instance, '_instance_key', None)
         mapper = object_mapper(instance)
         key = (key[0], mapper.table.hash_key(), key[2])
-        u = self.registry()
+        u = self.uow
         if key is not None:
             if u.identity_map.has_key(key):
                 return u.identity_map[key]
         else:
             u.register_new(instance)
         return instance
-    
 
 def get_id_key(ident, class_, table):
     return Session.get_id_key(ident, class_, table)
 def begin():
     """begins a new UnitOfWork transaction.  the next commit will affect only
     objects that are created, modified, or deleted following the begin statement."""
-    session().begin()
+    get_session().begin()
 
 def commit(*obj):
     """commits the current UnitOfWork transaction.  if a transaction was begun 
     via begin(), commits only those objects that were created, modified, or deleted
     since that begin statement.  otherwise commits all objects that have been
-    changed."""
-    session().commit(*obj)
+    changed.
+    
+    if individual objects are submitted, then only those objects are committed, and the 
+    begin/commit cycle is not affected."""
+    get_session().commit(*obj)
 
 def clear():
     """removes all current UnitOfWorks and IdentityMaps for this thread and 
     establishes a new one.  It is probably a good idea to discard all
     current mapped object instances, as they are no longer in the Identity Map."""
-    session().clear()
+    get_session().clear()
 
 def delete(*obj):
     """registers the given objects as to be deleted upon the next commit"""
-    s = session()
-    for o in obj:
-        s.register_deleted(o)
+    s = get_session().delete(*obj)
 
 def has_key(key):
     """returns True if the current thread-local IdentityMap contains the given instance key"""
-    return session().has_key(key)
+    return get_session().has_key(key)
 
 def has_instance(instance):
     """returns True if the current thread-local IdentityMap contains the given instance"""
-    return session().has_instance(instance)
+    return get_session().has_instance(instance)
 
 def is_dirty(obj):
     """returns True if the given object is in the current UnitOfWork's new or dirty list,
     or if its a modified list attribute on an object."""
-    return session().is_dirty(obj)
+    return get_session().is_dirty(obj)
 
 def instance_key(instance):
     """returns the IdentityMap key for the given instance"""
-    return session().instance_key(instance)
+    return get_session().instance_key(instance)
 
 def import_instance(instance):
-    return session().import_instance(instance)
+    return get_session().import_instance(instance)
 
 class UOWListElement(attributes.ListElement):
     def __init__(self, obj, key, data=None, deleteremoved=False, **kwargs):
         attributes.ListElement.__init__(self, obj, key, data=data, **kwargs)
         self.deleteremoved = deleteremoved
     def list_value_changed(self, obj, key, item, listval, isdelete):
-        sess = session(obj)
+        sess = get_session(obj)
         if not isdelete and sess.deleted.contains(item):
             raise "re-inserting a deleted value into a list"
         sess.modified_lists.append(self)
         
     def value_changed(self, obj, key, value):
         if hasattr(obj, '_instance_key'):
-            session(obj).register_dirty(obj)
+            get_session(obj).register_dirty(obj)
         else:
-            session(obj).register_new(obj)
+            get_session(obj).register_new(obj)
 
     def create_list(self, obj, key, list_, **kwargs):
         return UOWListElement(obj, key, list_, **kwargs)
         
 class UnitOfWork(object):
-    def __init__(self, session, parent=None, is_begun=False):
-        self.session = session
-        self.is_begun = is_begun
-        if is_begun:
-            self.begin_count = 1
-        else:
-            self.begin_count = 0
-        if parent is not None:
-            self.identity_map = parent.identity_map
+    def __init__(self, identity_map=None):
+        if identity_map is not None:
+            self.identity_map = identity_map
         else:
             self.identity_map = weakref.WeakValueDictionary()
             
         self.dirty = util.HashSet()
         self.modified_lists = util.HashSet()
         self.deleted = util.HashSet()
-        self.parent = parent
 
     def get(self, class_, *id):
         """given a class and a list of primary key values in their table-order, locates the mapper 
         if not hasattr(obj, '_instance_key'):
             mapper = object_mapper(obj)
             obj._instance_key = mapper.instance_key(obj)
-        self.session.bind_to(obj)
         self._put(obj._instance_key, obj)
         self.attributes.commit(obj)
         
     def register_new(self, obj):
-        self.session.bind_to(obj)
         self.new.append(obj)
         
     def register_dirty(self, obj):
         except KeyError:
             pass
             
-    # TODO: tie in register_new/register_dirty with table transaction begins ?
-    def begin(self):
-        if self.is_begun:
-            self.begin_count += 1
-            return
-        u = UnitOfWork(self.session, parent=self, is_begun=True)
-        self.session.registry.set(u)
-        
     def commit(self, *objects):
-        if self.is_begun:
-            self.begin_count -= 1
-            if self.begin_count > 0:
-                return
         commit_context = UOWTransaction(self)
 
         if len(objects):
         except:
             for e in engines:
                 e.rollback()
-            if self.parent:
-                self.session.registry.set(self.parent)
             raise
         for e in engines:
             e.commit()
             
         commit_context.post_exec()
         
-        if self.parent:
-            self.session.registry.set(self.parent)
 
     def rollback_object(self, obj):
         """'rolls back' the attributes that have been changed on an object instance."""
 
 global_attributes = UOWAttributeManager()
 
-global_session = Session(scope="thread", hash_key='thread')
-uow = global_session.registry # Note: this is not a UnitOfWork, it is a ScopedRegistry that manages UnitOfWork objects
 
-_sessions = weakref.WeakValueDictionary()
-_sessions[global_session.hash_key] = global_session
+session_registry = util.ScopedRegistry(Session) # Default session registry
+_sessions = weakref.WeakValueDictionary() # all referenced sessions (including user-created)
 
-def session(obj=None):
+def get_session(obj=None):
     # object-specific session ?
     if obj is not None:
         # does it have a hash key ?
             except KeyError:
                 raise "Session '%s' referenced by object '%s' no longer exists" % (hashkey, repr(obj))
 
-    try:
-        # have a thread-locally defined session (via using_session) ?
-        return _sessions[thread.get_ident()]
-    except KeyError:
-        # nope, return the regular session
-        return global_session
+    return session_registry()
+
+uow = get_session # deprecated
 
 def push_session(sess):
     old = _sessions.get(thread.get_ident(), None)

lib/sqlalchemy/util.py

     def __init__(self, createfunc, scopefunc=None):
         self.createfunc = createfunc
         if scopefunc is None:
-            scopefunc = thread.get_ident
+            self.scopefunc = thread.get_ident
         else:
             self.scopefunc = scopefunc
         self.registry = {}

test/objectstore.py

         u = m.select()[0]
         print u.addresses[0].user
 
+class SessionTest(AssertMixin):
+    def setUpAll(self):
+        db.echo = False
+        users.create()
+        tables.user_data()
+        db.echo = testbase.echo
+    def tearDownAll(self):
+        db.echo = False
+        users.drop()
+        db.echo = testbase.echo
+    def setUp(self):
+        objectstore.get_session().clear()
+        clear_mappers()
+        
+    def test_nested_begin_commit(self):
+        """test nested session.begin/commit"""
+        class User(object):pass
+        m = mapper(User, users)
+        def name_of(id):
+            return users.select(users.c.user_id == id).execute().fetchone().user_name
+        name1 = "Oliver Twist"
+        name2 = 'Mr. Bumble'
+        self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1)
+        self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2)
+        s = objectstore.get_session()
+        s.begin()
+        s.begin()
+        m.get(7).user_name = name1
+        s.begin()
+        m.get(8).user_name = name2
+        s.commit()
+        self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1)
+        self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2)
+        s.commit()
+        self.assert_(name_of(7) != name1, msg="user_name should not be %s" % name1)
+        self.assert_(name_of(8) != name2, msg="user_name should not be %s" % name2)
+        s.commit()
+        self.assert_(name_of(7) == name1, msg="user_name should be %s" % name1)
+        self.assert_(name_of(8) == name2, msg="user_name should be %s" % name2)
+
+
 class PKTest(AssertMixin):
     def setUpAll(self):
         db.echo = False
     users.delete().execute()
     db.commit()
     
+def user_data():
+    users.insert().execute(
+        dict(user_id = 7, user_name = 'jack'),
+        dict(user_id = 8, user_name = 'ed'),
+        dict(user_id = 9, user_name = 'fred')
+    )
+    
 def data():
     delete()