Mike Bayer avatar Mike Bayer committed c14bae3

- added add()/add_all() to Session; save/update/save_or_update on deck for deprecation
- transactional=True becomes autocommit=False
- little bunny of state-revertible SessionTransaction pokes its head out. tests not complete
and InstanceState.savepoint() implementation needs to be reworked for greater efficiency.
- removed some non-working cases for mapper.get_session(); the single remaining case, for lazy loads, is very close to being removed
- worst case scenario of "setup savepoints for every instance laoded", when in transaction and autoexpire is set on non-default of "off", adds relatively minimal overhead to masseagerload.py

Comments (0)

Files changed (15)

lib/sqlalchemy/orm/__init__.py

 from sqlalchemy.orm.session import Session as _Session
 from sqlalchemy.orm.session import object_session, sessionmaker
 from sqlalchemy.orm.scoping import ScopedSession
-
+from sqlalchemy import util as __util
 
 __all__ = [ 'relation', 'column_property', 'composite', 'backref', 'eagerload',
             'eagerload_all', 'lazyload', 'noload', 'deferred', 'defer',
     It is recommended to use the [sqlalchemy.orm#sessionmaker()] function
     instead of create_session().
     """
+
+    if 'transactional' in kwargs:
+        __util.warn_deprecated("The 'transactional' argument to sessionmaker() is deprecated; use autocommit=True|False instead.")
+        autocommit = not kwargs.pop('transactional')
+    
     kwargs.setdefault('autoflush', False)
-    kwargs.setdefault('transactional', False)
+    kwargs.setdefault('autocommit', True)
+    kwargs.setdefault('autoexpire', False)
     return _Session(bind=bind, **kwargs)
 
 def relation(argument, secondary=None, **kwargs):

lib/sqlalchemy/orm/attributes.py

             ])
         for k in self.expired_attributes:
             self.callables.pop(k, None)
-        self.expired_attributes.clear()
+        del self.expired_attributes
         return ATTR_WAS_SET
 
     def unmodified(self):
         if attribute_names is None:
             attribute_names = self.manager.keys()
             self.expired = True
+            self.modified = False
         for key in attribute_names:
             self.dict.pop(key, None)
             self.committed_state.pop(key, None)
                 self.savepoints[-1][0][attr.key] = previous
                 
         self.modified = True
+    
+    # TODO: reorganize savepoints so that the savepoint record is managed externally to the InstanceState.
+    # InstanceState only references the most recent savepoint so that its available in modified_event().
+    
+    def set_savepoint(self, id_=None):
+        self.savepoints.append(({}, self.parents.copy(), self.pending.copy(), self.committed_state.copy(), self.modified, id_))
 
-    def set_savepoint(self):
-        self.savepoints.append(({}, self.parents.copy(), self.pending.copy(), self.committed_state.copy()))
+    def remove_savepoint(self, id_=None):
+        if not self.savepoints:
+            raise sa_exc.AssertionError("Savepoint id %s does not exist"  % (id_))
+            
+        sp = self.savepoints.pop()
+        spid = sp[5]
+        if spid != id_:
+            raise sa_exc.AssertionError("Savepoint id %s does not match %s"  % (spid, id_))
 
-    def remove_savepoint(self):
-        self.savepoints.pop()
-
-    def rollback(self):
+    def rollback(self, id_=None):
         if not self.savepoints:
             raise sa_exc.InvalidRequestError("No savepoints are set; can't rollback.")
         
-        (savepoint, self.parents, self.pending, self.committed_state) = self.savepoints.pop()
+        (savepoint, self.parents, self.pending, self.committed_state, self.modified, spid) = self.savepoints.pop()
+        if spid != id_:
+            raise sa_exc.AssertionError("Savepoint id %s does not match %s"  % (spid, id_))
         
         for attr in self.manager.attributes:
             if attr.impl.key in savepoint:

lib/sqlalchemy/orm/dynamic.py

     
     def session(self):
         return self.__session()
-    session = property(session)
+    session = property(session, lambda s, x:None)
     
     def __iter__(self):
         sess = self.__session()

lib/sqlalchemy/orm/mapper.py

                 state = attributes.instance_state(instance)
                 state.entity_name = self.entity_name
                 state.key = identitykey
+                # manually adding instance to session.  for a complete add,
+                # session._finalize_loaded() must be called.
                 state.session_id = context.session.hash_key
                 session_identity_map.add(state)
 
     mapper = _state_mapper(state)
     session = _state_session(state)
     if not session:
-        try:
-            session = mapper.get_session()
-        except sa_exc.InvalidRequestError:
-            raise sa_exc.UnboundExecutionError("Instance %s is not bound to a Session, and no contextual session is established; attribute refresh operation cannot proceed" % (state_str(state)))
+        raise sa_exc.UnboundExecutionError("Instance %s is not bound to a Session; attribute refresh operation cannot proceed" % (state_str(state)))
 
     has_key = _state_has_identity(state)
 

lib/sqlalchemy/orm/query.py

     """Encapsulates the object-fetching operations provided by Mappers."""
 
     def __init__(self, entities, session=None, entity_name=None):
-        self._session = session
+        self.session = session
         
         self._with_options = []
         self._lockmode = None
         q.__dict__ = self.__dict__.copy()
         return q
 
-    def session(self):
-        if self._session is None:
-            return self._mapper_zero().get_session()
-        else:
-            return self._session
-    session = property(session)
-
     def statement(self):
         """return the full SELECT statement represented by this Query."""
         return self._compile_context().statement
                 context.refresh_instance.commit(self._only_load_props)
                 context.progress.remove(context.refresh_instance)
 
-            for ii in context.progress:
-                ii.commit_all()
+            session._finalize_loaded(context.progress)
                 
             for ii, attrs in context.partials.items():
                 ii.commit(attrs)

lib/sqlalchemy/orm/session.py

 
 __all__ = ['Session', 'SessionTransaction', 'SessionExtension']
 
-
-def sessionmaker(bind=None, class_=None, autoflush=True, transactional=True, **kwargs):
+def sessionmaker(bind=None, class_=None, autoflush=True, autocommit=False, autoexpire=True, **kwargs):
     """Generate a custom-configured [sqlalchemy.orm.session#Session] class.
 
     The returned object is a subclass of ``Session``, which, when instantiated with no
     which should be used by the returned class.  All other keyword arguments sent to
     `sessionmaker()` are passed through to the instantiated `Session()` object.
     """
-
+    
+    if 'transactional' in kwargs:
+        util.warn_deprecated("The 'transactional' argument to sessionmaker() is deprecated; use autocommit=True|False instead.")
+        autocommit = not kwargs.pop('transactional')
+        
     kwargs['bind'] = bind
     kwargs['autoflush'] = autoflush
-    kwargs['transactional'] = transactional
+    kwargs['autocommit'] = autocommit
+    kwargs['autoexpire'] = autoexpire
 
     if class_ is None:
         class_ = Session
 
-    class Sess(class_):
+    class Sess(object):
         def __init__(self, **local_kwargs):
             for k in kwargs:
                 local_kwargs.setdefault(k, kwargs[k])
 
             kwargs.update(new_kwargs)
         configure = classmethod(configure)
-
-    return Sess
+    s = type.__new__(type, "Session", (Sess, class_), {})
+    return s
 
 
 class SessionTransaction(object):
         self.nested = nested
         self._active = True
         self._prepared = False
+        if not parent and nested:
+            raise sa_exc.InvalidRequestError("Can't start a SAVEPOINT transaction when no existing transaction is in progress")
+        self._take_snapshot()
 
     def is_active(self):
         return self.session is not None and self._active
     def _assert_is_open(self):
         if self.session is None:
             raise sa_exc.InvalidRequestError("The transaction is closed")
-
+    
+    def _is_transaction_boundary(self):
+        return self.nested or not self._parent
+    _is_transaction_boundary = property(_is_transaction_boundary)
+    
     def connection(self, bindkey, **kwargs):
         self._assert_is_active()
         engine = self.session.get_bind(bindkey, **kwargs)
             if self._parent is None:
                 raise sa_exc.InvalidRequestError("Transaction %s is not on the active transaction list" % upto)
             return (self,) + self._parent._iterate_parents(upto)
+    
+    def _take_snapshot(self):
+        if not self._is_transaction_boundary:
+            self._new = self._parent._new
+            self._deleted = self._parent._deleted
+            self._snapshot_id = self._parent._snapshot_id
+            return
+        
+        if self.autoflush:
+            assert not self.session._new
+            assert not self.session._deleted
+            assert not self.session._dirty_states
+        
+        self._new = weakref.WeakKeyDictionary()
+        self._deleted = weakref.WeakKeyDictionary()
+        self._snapshot_id = id_ = id(self)
+        
+        if self.nested or not self.session.autoexpire:
+            for state in self.session.identity_map.all_states():
+                state.set_savepoint(id_)
+    
+    def _restore_snapshot(self):
+        assert self._is_transaction_boundary
+        
+        for s in util.Set(self._deleted).union(self.session._deleted):
+            self.session._update_impl(s)
+            
+        for s in util.Set(self._new).union(self.session._new):
+            self.session._expunge_state(s)
+        
+        expire = not self.nested and self.session.autoexpire
+        id_ = self._snapshot_id
+        for s in self.session.identity_map.all_states():
+            if expire:
+                _expire_state(s, None)
+            else:
+                s.rollback(id_)
+    
+    def _remove_snapshot(self):
+        assert self._is_transaction_boundary
 
+        if self.nested or not self.session.autoexpire:
+            id_ = self._snapshot_id
+            for s in self.session.identity_map.all_states():
+                s.remove_savepoint(id_)
+            
     def add(self, bind):
         self._assert_is_active()
         if self._parent is not None and not self.nested:
 
             if self.session.extension is not None:
                 self.session.extension.after_commit(self.session)
-
+            
+            self._remove_snapshot()
+                
         self._close()
         return self._parent
     commit = util.deprecated()(_commit)
                     break
                 else:
                     transaction._deactivate()
+            
+                
         self._close()
         return self._parent
     rollback = util.deprecated()(_rollback)
         for t in util.Set(self._connections.values()):
             t[1].rollback()
 
+        self._restore_snapshot()
+
         if self.session.extension is not None:
             self.session.extension.after_rollback(self.session)
 
     a thread-managed Session adapter, provided by the [sqlalchemy.orm#scoped_session()] function.
     """
 
-    def __init__(self, bind=None, autoflush=True, transactional=False, twophase=False, echo_uow=False, weak_identity_map=True, binds=None, extension=None):
+    def __init__(self, bind=None, autoflush=True, autoexpire=True, autocommit=False, twophase=False, echo_uow=False, weak_identity_map=True, binds=None, extension=None):
         """Construct a new Session.
         
         A session is usually constructed using the [sqlalchemy.orm#create_session()] function, 
         or its more "automated" variant [sqlalchemy.orm#sessionmaker()].
 
+        autoexpire
+            When ``True``, all instances will be fully expired after each ``rollback()``
+            and after each ``commit()``, so that all attribute/object access subsequent
+            to a completed transaction will load from the most recent database state.
+            
         autoflush
             When ``True``, all query operations will issue a ``flush()`` call to this
             ``Session`` before proceeding. This is a convenience feature so that
         self.transaction = None
         self.hash_key = id(self)
         self.autoflush = autoflush
-        self.transactional = transactional
+        self.autocommit = autocommit
+        self.autoexpire = autoexpire
         self.twophase = twophase
         self.extension = extension
         self._query_cls = query.Query
                     for t in mapperortable._all_tables:
                         self.__binds[t] = value
 
-        if self.transactional:
+        if not self.autocommit:
             self.begin()
         _sessions[self.hash_key] = self
 
             pass
         else:
             self.transaction._rollback()
-        if self.transaction is None and self.transactional:
+        if self.transaction is None and not self.autocommit:
             self.begin()
 
     def commit(self):
         """
 
         if self.transaction is None:
-            if self.transactional:
+            if not self.autocommit:
                 self.begin()
             else:
                 raise sa_exc.InvalidRequestError("No transaction is begun.")
 
         self.transaction._commit()
-        if self.transaction is None and self.transactional:
+        if self.transaction is None and not self.autocommit:
             self.begin()
     
     def prepare(self):
         not such, an InvalidRequestError is raised.
         """
         if self.transaction is None:
-            if self.transactional:
+            if not self.autocommit:
                 self.begin()
             else:
                 raise sa_exc.InvalidRequestError("No transaction is begun.")
         if self.transaction is not None:
             for transaction in self.transaction._iterate_parents():
                 transaction._close()
-        if self.transactional:
+        if not self.autocommit:
             # note this doesnt use any connection resources
             self.begin()
 
     def _autoflush(self):
         if self.autoflush and (self.transaction is None or self.transaction.autoflush):
             self.flush()
+    
+    def _finalize_loaded(self, states):
+        if not self.autoexpire and self.transaction:
+            snapshot_id = self.transaction._snapshot_id
+        else:
+            snapshot_id = None    
+        for state in states:
+            state.commit_all()
+            if snapshot_id:
+                state.set_savepoint(snapshot_id)
 
     def get(self, class_, ident, entity_name=None):
         """Return an instance of the object based on the given
         Cascading will be applied according to the *expunge* cascade
         rule.
         """
-        self._expunge_state(attributes.instance_state(instance))
         
+        state = attributes.instance_state(instance)
+        if state.session_id is not self.hash_key:
+            raise sa_exc.InvalidRequestError("Instance %s is not present in this Session" % mapperutil.state_str(state))
+        for s, m in [(state, None)] + list(_cascade_state_iterator('expunge', state)):
+            self._expunge_state(s)
+    
     def _expunge_state(self, state):
-        for s, m in [(state, None)] + list(_cascade_state_iterator('expunge', state)):
-            if s in self._new:
-                self._new.pop(s)
-                del s.session_id
-            elif self.identity_map.contains_state(s):
-                self._remove_persistent(s)
+        if state in self._new:
+            self._new.pop(state)
+            del state.session_id
+        elif self.identity_map.contains_state(state):
+            self.identity_map.discard(state)
+            self._deleted.pop(state, None)
+            del state.session_id
+
+    def _register_newly_persistent(self, state):
+        if self.transaction:
+            self.transaction._new[state] = True
+            
+        mapper = _state_mapper(state)
+        instance_key = mapper._identity_key_from_state(state)
+
+        if state.key is None:
+            state.key = instance_key
+        elif state.key != instance_key:
+            # primary key switch
+            self.identity_map.remove(state)
+            state.key = instance_key
+
+        if hasattr(state, 'insert_order'):
+            delattr(state, 'insert_order')
+
+        obj = state.obj()
+        # prevent against last minute dereferences of the object
+        # TODO: identify a code path where state.obj() is None
+        if obj is not None:
+            if state.key in self.identity_map and not self.identity_map.contains_state(state):
+                self.identity_map.remove_key(state.key)
+            self.identity_map.add(state)
+            state.commit_all()
+
+        # remove from new last, might be the last strong ref
+        self._new.pop(state, None)
         
-    def _remove_persistent(self, state):
+    def _remove_newly_deleted(self, state):
+        if self.transaction:
+            self.transaction._deleted[state] = True
+            
         self.identity_map.discard(state)
         self._deleted.pop(state, None)
         del state.session_id
         self._save_impl(state)
         self._cascade_save_or_update(state, entity_name)
     
+    # TODO
+    #save = util.deprecated("Use the add() method.")(save)
+    
     def _save_without_cascade(self, instance, entity_name=None):
         """used by scoping.py to save on init without cascade."""
         
         state = attributes.instance_state(instance)
         self._update_impl(state)
         self._cascade_save_or_update(state, entity_name)
-
-    def save_or_update(self, instance, entity_name=None):
-        """Save or update the given instance into this ``Session``.
+        
+    # TODO
+    #update = util.deprecated("Use the add() method.")(update)
+    
+    def add(self, instance, entity_name=None):
+        """Add the given instance into this ``Session``.
 
         The non-None state `key` on the instance's state determines whether
         to ``save()`` or ``update()`` the instance.
         """
         state = _state_for_unknown_persistence_instance(instance, entity_name)
         self._save_or_update_state(state, entity_name)
+    
+    def add_all(self, instances):
+        """Add the given collection of instances to this ``Session``."""
         
+        for instance in instances:
+            self.add(instance)
+        
+    # TODO
+    # save_or_update = util.deprecated("Use the add() method.")(add)
+    save_or_update = add
+    
     def _save_or_update_state(self, state, entity_name):
         self._save_or_update_impl(state)
         self._cascade_save_or_update(state, entity_name)
                 "Object '%s' is already attached to session '%s' "
                 "(this is '%s')" % (mapperutil.state_str(state),
                                     state.session_id, self.hash_key))
-        state.session_id = self.hash_key
+        if state.session_id != self.hash_key:
+            state.session_id = self.hash_key
+            if not self.autoexpire and self.transaction:
+                state.set_savepoint(self.transaction._snapshot_id)
 
     def __contains__(self, instance):
         """Return True if the given instance is associated with this session.
     def _contains_state(self, state):
         return state in self._new or self.identity_map.contains_state(state)
 
-    def _register_newly_persistent(self, state):
-
-        mapper = _state_mapper(state)
-        instance_key = mapper._identity_key_from_state(state)
-
-        if state.key is None:
-            state.key = instance_key
-        elif state.key != instance_key:
-            # primary key switch
-            self.identity_map.remove(state)
-            state.key = instance_key
-
-        if hasattr(state, 'insert_order'):
-            delattr(state, 'insert_order')
-
-        obj = state.obj()
-        # prevent against last minute dereferences of the object
-        # TODO: identify a code path where state.obj() is None
-        if obj is not None:
-            if state.key in self.identity_map and not self.identity_map.contains_state(state):
-                self.identity_map.remove_key(state.key)
-            self.identity_map.add(state)
-            state.commit_all()
-
-        # remove from new last, might be the last strong ref
-        self._new.pop(state, None)
 
     def flush(self, objects=None):
         """Flush all the object modifications present in this session
 
         if len(flush_context.tasks) == 0:
             return
-
-        self.create_transaction(autoflush=False)
-        flush_context.transaction = self.transaction
+        
+        flush_context.transaction = transaction = self.create_transaction(autoflush=False)
         try:
             flush_context.execute()
 
             if self.extension is not None:
                 self.extension.after_flush(self, flush_context)
-            self.commit()
+            transaction._commit()
         except:
-            self.rollback()
-            flush_context.remove_flush_changes()
+            transaction._rollback()
             raise
 
         flush_context.finalize_flush_changes()
     return state
 
 def _state_for_unknown_persistence_instance(instance, entity_name):
-    try:
-        state = attributes.instance_state(instance)
-        state.entity_name = entity_name
-        return state
-    except AttributeError:
-        return self._state_for_unsaved_instance(instance, entity_name)
+    state = attributes.instance_state(instance)
+    state.entity_name = entity_name
+    return state
 
 def object_session(instance):
     """Return the ``Session`` to which the given instance is bound, or ``None`` if none."""

lib/sqlalchemy/orm/unitofwork.py

         if postupdate:
             task.append_postupdate(state, post_update_cols)
         else:
-            task.append(state, listonly=listonly, apply_savepoint=True, isdelete=isdelete)
+            task.append(state, listonly=listonly, isdelete=isdelete)
 
     def set_row_switch(self, state):
         """mark a deleted object as a 'row switch'.
                 yield elem
     elements = property(elements)
     
-    def remove_flush_changes(self):
-        for elem in self.elements:
-            elem.state.rollback()
-            
     def finalize_flush_changes(self):
         """mark processed objects as clean / deleted after a successful flush().
         
         """
 
         for elem in self.elements:
-            elem.state.remove_savepoint()
             if elem.isdelete:
-                self.session._remove_persistent(elem.state)
+                self.session._remove_newly_deleted(elem.state)
             else:
                 self.session._register_newly_persistent(elem.state)
 
 
         return not self._objects and not self.dependencies
             
-    def append(self, state, apply_savepoint=False, listonly=False, isdelete=False):
+    def append(self, state, listonly=False, isdelete=False):
         if state not in self._objects:
             self._objects[state] = rec = UOWTaskElement(state)
-            if apply_savepoint:
-                state.set_savepoint()
         else:
             rec = self._objects[state]
         

test/orm/alltests.py

         'orm.naturalpks',
         'orm.unitofwork',
         'orm.session',
+        'orm.transaction',
         'orm.scoping',
         'orm.cascade',
         'orm.relationships',

test/orm/expire.py

         
         # object is gone, get() raises
         self.assertRaises(orm_exc.ObjectDeletedError, s.get, User, 10)
+    
+    def test_refresh_cancels_expire(self):
+        mapper(User, users)
+        s = create_session()
+        u = s.get(User, 7)
+        s.expire(u)
+        s.refresh(u)
+        
+        def go():
+            u = s.get(User, 7)
+            self.assertEquals(u.name, 'jack')
+        self.assert_sql_count(testing.db, go, 0)
         
     def test_expire_doesntload_on_set(self):
         mapper(User, users)

test/orm/mapper.py

         m.add_property('uc_user_name2', comparable_property(
                 UCComparator, User.uc_user_name2))
 
-        sess = create_session(transactional=True)
+        sess = create_session(autocommit=False)
         assert sess.query(User).get(7)
 
         u = sess.query(User).filter_by(user_name='jack').one()

test/orm/session.py

         conn1 = testing.db.connect()
         conn2 = testing.db.connect()
 
-        sess = create_session(transactional=True, bind=conn1)
+        sess = create_session(autocommit=False, bind=conn1)
         u = User()
         sess.save(u)
         sess.flush()
         conn1 = testing.db.connect()
         conn2 = testing.db.connect()
 
-        sess = create_session(bind=conn1, transactional=True, autoflush=True)
+        sess = create_session(bind=conn1, autocommit=False, autoflush=True)
         u = User()
         u.user_name='ed'
         sess.save(u)
         })
         mapper(Address, addresses)
 
-        sess = create_session(autoflush=True, transactional=True)
+        sess = create_session(autoflush=True, autocommit=False)
         u = User(user_name='ed', addresses=[Address(email_address='foo')])
         sess.save(u)
         self.assertEquals(sess.query(Address).filter(Address.user==u).one(), Address(email_address='foo'))
         mapper(User, users)
 
         try:
-            sess = create_session(transactional=True, autoflush=True)
+            sess = create_session(autocommit=False, autoflush=True)
             u = User()
             u.user_name='ed'
             sess.save(u)
         conn1 = testing.db.connect()
         conn2 = testing.db.connect()
 
-        sess = create_session(bind=conn1, transactional=True, autoflush=True)
+        sess = create_session(bind=conn1, autocommit=False, autoflush=True)
         u = User()
         u.user_name='ed'
         sess.save(u)
             'addresses':relation(Address)
         })
 
-        sess = create_session(transactional=True, autoflush=True)
+        sess = create_session(autocommit=False, autoflush=True)
         u = sess.query(User).get(8)
         newad = Address()
         newad.email_address == 'something new'
         mapper(User, users)
         conn = testing.db.connect()
         trans = conn.begin()
-        sess = create_session(bind=conn, transactional=True, autoflush=True)
+        sess = create_session(bind=conn, autocommit=False, autoflush=True)
         sess.begin()
         u = User()
         sess.save(u)
         try:
             conn = testing.db.connect()
             trans = conn.begin()
-            sess = create_session(bind=conn, transactional=True, autoflush=True)
+            sess = create_session(bind=conn, autocommit=False, autoflush=True)
             u1 = User()
             sess.save(u1)
             sess.flush()
     def test_joined_transaction(self):
         class User(object):pass
         mapper(User, users)
-        sess = create_session(transactional=True, autoflush=True)
+        sess = create_session(autocommit=False, autoflush=True)
         sess.begin()
         u = User()
         sess.save(u)
     def test_nested_autotrans(self):
         class User(object):pass
         mapper(User, users)
-        sess = create_session(transactional=True)
+        sess = create_session(autocommit=False)
         u = User()
         sess.save(u)
         sess.flush()
         class User(object): pass
         mapper(User, users)
 
-        sess = create_session(transactional=True)
+        sess = create_session(autocommit=False)
 
         sess.begin_nested()
 
         mapper(User, users)
         c = testing.db.connect()
 
-        sess = create_session(bind=c, transactional=True)
+        sess = create_session(bind=c, autocommit=False)
         u = User()
         sess.save(u)
         sess.flush()
         assert not c.in_transaction()
         assert c.scalar("select count(1) from users") == 0
 
-        sess = create_session(bind=c, transactional=True)
+        sess = create_session(bind=c, autocommit=False)
         u = User()
         sess.save(u)
         sess.flush()
         c = testing.db.connect()
 
         trans = c.begin()
-        sess = create_session(bind=c, transactional=False)
+        sess = create_session(bind=c, autocommit=True)
         u = User()
         sess.save(u)
         sess.flush()
         assert log == ['before_flush', 'after_flush', 'before_commit', 'after_commit', 'after_flush_postexec']
 
         log = []
-        sess = create_session(transactional=True, extension=MyExt())
+        sess = create_session(autocommit=False, extension=MyExt())
         u = User()
         sess.save(u)
         sess.flush()

test/orm/sharding/shard.py

             else:
                 return ids
 
-        create_session = sessionmaker(class_=ShardedSession, autoflush=True, transactional=True)
+        create_session = sessionmaker(class_=ShardedSession, autoflush=True, autocommit=False)
 
         create_session.configure(shards={
             'north_america':db1,

test/orm/transaction.py

+import testenv; testenv.configure_for_tests()
+import operator
+from sqlalchemy import *
+from sqlalchemy import exc as sa_exc
+from sqlalchemy.orm import *
+from testlib.testing import *
+from testlib.fixtures import *
+
+class TransactionTest(FixtureTest):
+    keep_mappers = True
+    refresh_data = True
+
+    def setup_mappers(self):
+        mapper(User, users, properties={
+            'addresses':relation(Address, backref='user'),
+            })
+        mapper(Address, addresses)
+
+class ActiveRollbackTest(object):
+    def test_attrs_on_rollback(self):
+        sess = self.session()
+        u1 = sess.get(User, 7)
+        u1.name = 'ed'
+        sess.rollback()
+        self.assertEquals(u1.name, 'jack')
+    
+    def test_expunge_pending_on_rollback(self):
+        sess = self.session()
+        u2= User(name='newuser')
+        sess.add(u2)
+        assert u2 in sess
+        sess.rollback()
+        assert u2 not in sess
+    
+    def test_commit_persistent(self):
+        sess = self.session()
+        u1 = sess.get(User, 7)
+        u1.name = 'ed'
+        sess.flush()
+        sess.commit()
+        self.assertEquals(u1.name, 'ed')
+
+    def test_commit_pending(self):
+        sess = self.session()
+        u2 = User(name='newuser')
+        sess.add(u2)
+        sess.flush()
+        sess.commit()
+        self.assertEquals(u2.name, 'newuser')
+        
+# TODO!  subtransactions
+# TODO!  SAVEPOINT transactions
+# TODO!  continuing transactions after rollback()
+
+class AutoExpireTest(ActiveRollbackTest, TransactionTest):
+    def session(self):
+        return create_session(autoflush=True, autocommit=False, autoexpire=True)
+
+class AttrSavepointTest(ActiveRollbackTest, TransactionTest):
+    def session(self):
+        return create_session(autoflush=True, autocommit=False, autoexpire=False)
+
+
+class AutocommitTest(TestBase):
+    def test_begin_nested_requires_trans(self):
+        sess = create_session(autocommit=True)
+        self.assertRaises(sa_exc.InvalidRequestError, sess.begin_nested)
+
+
+
+if __name__ == '__main__':
+    testenv.main()

test/orm/unitofwork.py

 
 # TODO: convert suite to not use Session.mapper, use fixtures.Base
 # with explicit session.save()
-Session = scoped_session(sessionmaker(autoflush=True, transactional=True))
+Session = scoped_session(sessionmaker(autoflush=True, autocommit=False))
 orm_mapper = mapper
 mapper = Session.mapper
 
         s2.commit()
 
         f1.value='f1rev3mine'
-        success = False
-        try:
-            # a concurrent session has modified this, should throw
-            # an exception
-            s.commit()
-        except orm_exc.ConcurrentModificationError, e:
-            #print e
-            success = True
 
         # Only dialects with a sane rowcount can detect the ConcurrentModificationError
         if testing.db.dialect.supports_sane_rowcount:
-            assert success
-
-        s.close()
+            self.assertRaises(orm_exc.ConcurrentModificationError, s.commit)
+            s.rollback()
+        else:
+            s.commit()
+        
+        # new in 0.5 !  dont need to close the session
         f1 = s.query(Foo).get(f1.id)
         f2 = s.query(Foo).get(f2.id)
 
 
         s.delete(f1)
         s.delete(f2)
-        success = False
-        try:
+
+        if testing.db.dialect.supports_sane_multi_rowcount:
+            self.assertRaises(orm_exc.ConcurrentModificationError, s.commit)
+        else:
             s.commit()
-        except orm_exc.ConcurrentModificationError, e:
-            #print e
-            success = True
-        if testing.db.dialect.supports_sane_multi_rowcount:
-            assert success
 
     @engines.close_open_connections
     def test_versioncheck(self):
         f1s2 = s2.query(Foo).get(f1s1.id)
         f1s2.value='f1 new value'
         s2.commit()
-        try:
-            # load, version is wrong
-            s1.query(Foo).with_lockmode('read').get(f1s1.id)
-            assert False
-        except orm_exc.ConcurrentModificationError, e:
-            assert True
+        # load, version is wrong
+        self.assertRaises(orm_exc.ConcurrentModificationError, s1.query(Foo).with_lockmode('read').get, f1s1.id)
+
         # reload it
         s1.query(Foo).load(f1s1.id)
         # now assert version OK
         orm_mapper(T2, t2)
 
     def test_close_transaction_on_commit_fail(self):
-        Session = sessionmaker(autoflush=False, transactional=False)
+        Session = sessionmaker(autoflush=False, autocommit=True)
         sess = Session()
 
         # with a deferred constraint, this fails at COMMIT time instead

test/perf/masseagerload.py

 
 @profiling.profiled('masseagerload', always=True, sort=['cumulative'])
 def masseagerload(session):
+    session.begin()
     query = session.query(Item)
     l = query.all()
+    session.commit()
     print "loaded ", len(l), " items each with ", len(l[0].subs), "subitems"
 
 def all():
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.