Mike Bayer avatar Mike Bayer committed 438b18e

- add some more transaction states so that we deliver a more accurate
message for [ticket:2662]; after_commit() is called within "committed"
state, not prepared, and no SQL can be emitted for prepared or committed
- consolidate state assertions in session transaction, use just one
method
- add more unit tests for these assertions

Comments (0)

Files changed (2)

lib/sqlalchemy/orm/session.py

 
 ACTIVE = util.symbol('ACTIVE')
 PREPARED = util.symbol('PREPARED')
+COMMITTED = util.symbol('COMMITTED')
 DEACTIVE = util.symbol('DEACTIVE')
+CLOSED = util.symbol('CLOSED')
 
 class SessionTransaction(object):
     """A :class:`.Session`-level transaction.
     def is_active(self):
         return self.session is not None and self._state is ACTIVE
 
-    def _assert_is_active(self):
-        self._assert_is_open()
-        if self._state is PREPARED:
+    def _assert_active(self, prepared_ok=False,
+                        rollback_ok=False,
+                        closed_msg="This transaction is closed"):
+        if self._state is COMMITTED:
             raise sa_exc.InvalidRequestError(
-                    "This session is in 'prepared' state, where no further "
-                    "SQL can be emitted until the transaction is fully "
-                    "committed."
+                    "This session is in 'committed' state; no further "
+                    "SQL can be emitted within this transaction."
                 )
+        elif self._state is PREPARED:
+            if not prepared_ok:
+                raise sa_exc.InvalidRequestError(
+                        "This session is in 'prepared' state; no further "
+                        "SQL can be emitted within this transaction."
+                    )
         elif self._state is DEACTIVE:
-            if self._rollback_exception:
-                raise sa_exc.InvalidRequestError(
-                    "This Session's transaction has been rolled back "
-                    "due to a previous exception during flush."
-                    " To begin a new transaction with this Session, "
-                    "first issue Session.rollback()."
-                    " Original exception was: %s"
-                    % self._rollback_exception
-                )
-            else:
-                raise sa_exc.InvalidRequestError(
-                    "This Session's transaction has been rolled back "
-                    "by a nested rollback() call.  To begin a new "
-                    "transaction, issue Session.rollback() first."
+            if not rollback_ok:
+                if self._rollback_exception:
+                    raise sa_exc.InvalidRequestError(
+                        "This Session's transaction has been rolled back "
+                        "due to a previous exception during flush."
+                        " To begin a new transaction with this Session, "
+                        "first issue Session.rollback()."
+                        " Original exception was: %s"
+                        % self._rollback_exception
                     )
-
-    def _assert_is_open(self, error_msg="The transaction is closed"):
-        if self.session is None:
-            raise sa_exc.ResourceClosedError(error_msg)
+                else:
+                    raise sa_exc.InvalidRequestError(
+                        "This Session's transaction has been rolled back "
+                        "by a nested rollback() call.  To begin a new "
+                        "transaction, issue Session.rollback() first."
+                        )
+        elif self._state is CLOSED:
+            raise sa_exc.ResourceClosedError(closed_msg)
 
     @property
     def _is_transaction_boundary(self):
         return self.nested or not self._parent
 
     def connection(self, bindkey, **kwargs):
-        self._assert_is_active()
+        self._assert_active()
         bind = self.session.get_bind(bindkey, **kwargs)
         return self._connection_for_bind(bind)
 
     def _begin(self, nested=False):
-        self._assert_is_active()
+        self._assert_active()
         return SessionTransaction(
             self.session, self, nested=nested)
 
 
 
     def _connection_for_bind(self, bind):
-        self._assert_is_active()
+        self._assert_active()
 
         if bind in self._connections:
             return self._connections[bind][0]
     def prepare(self):
         if self._parent is not None or not self.session.twophase:
             raise sa_exc.InvalidRequestError(
-                "Only root two phase transactions of can be prepared")
+                "'twophase' mode not enabled, or not root transaction; "
+                "can't prepare.")
         self._prepare_impl()
 
     def _prepare_impl(self):
-        self._assert_is_active()
+        self._assert_active()
         if self._parent is None or self.nested:
             self.session.dispatch.before_commit(self.session)
 
         self._state = PREPARED
 
     def commit(self):
-        self._assert_is_open()
+        self._assert_active(prepared_ok=True)
         if self._state is not PREPARED:
             self._prepare_impl()
 
             for t in set(self._connections.values()):
                 t[1].commit()
 
+            self._state = COMMITTED
             self.session.dispatch.after_commit(self.session)
 
             if self.session._enable_transaction_accounting:
         return self._parent
 
     def rollback(self, _capture_exception=False):
-        self._assert_is_open()
+        self._assert_active(prepared_ok=True, rollback_ok=True)
 
         stx = self.session.transaction
         if stx is not self:
         sess = self.session
 
         if self.session._enable_transaction_accounting and \
-            not sess._is_clean():
+                not sess._is_clean():
             # if items were added, deleted, or mutated
             # here, we need to re-restore the snapshot
             util.warn(
         self.session.transaction = self._parent
         if self._parent is None:
             for connection, transaction, autoclose in \
-                set(self._connections.values()):
+                    set(self._connections.values()):
                 if autoclose:
                     connection.close()
                 else:
                     transaction.close()
 
-        self._state = DEACTIVE
+        self._state = CLOSED
         if self.session.dispatch.after_transaction_end:
             self.session.dispatch.after_transaction_end(self.session, self)
 
         return self
 
     def __exit__(self, type, value, traceback):
-        self._assert_is_open("Cannot end transaction context. The transaction "
-                                    "was closed from within the context")
+        self._assert_active(prepared_ok=True)
         if self.session.transaction is None:
             return
         if type is None:

test/orm/test_transaction.py

                               sess.begin, subtransactions=True)
         sess.close()
 
-    def test_no_sql_during_prepare(self):
+    def test_no_sql_during_commit(self):
         sess = create_session(bind=testing.db, autocommit=False)
 
         @event.listens_for(sess, "after_commit")
         def go(session):
             session.execute("select 1")
         assert_raises_message(sa_exc.InvalidRequestError,
-                    "This session is in 'prepared' state, where no "
-                    "further SQL can be emitted until the "
-                    "transaction is fully committed.",
+                    "This session is in 'committed' state; no further "
+                    "SQL can be emitted within this transaction.",
                     sess.commit)
 
+    def test_no_sql_during_prepare(self):
+        sess = create_session(bind=testing.db, autocommit=False, twophase=True)
+
+        sess.prepare()
+
+        assert_raises_message(sa_exc.InvalidRequestError,
+                    "This session is in 'prepared' state; no further "
+                    "SQL can be emitted within this transaction.",
+                    sess.execute, "select 1")
+
+    def test_no_prepare_wo_twophase(self):
+        sess = create_session(bind=testing.db, autocommit=False)
+
+        assert_raises_message(sa_exc.InvalidRequestError,
+                    "'twophase' mode not enabled, or not root "
+                    "transaction; can't prepare.",
+                    sess.prepare)
+
+    def test_closed_status_check(self):
+        sess = create_session()
+        trans = sess.begin()
+        trans.rollback()
+        assert_raises_message(
+                sa_exc.ResourceClosedError,
+                "This transaction is closed",
+                trans.rollback
+        )
+        assert_raises_message(
+                sa_exc.ResourceClosedError,
+                "This transaction is closed",
+                trans.commit
+        )
+
+    def test_deactive_status_check(self):
+        sess = create_session()
+        trans = sess.begin()
+        trans2 = sess.begin(subtransactions=True)
+        trans2.rollback()
+        assert_raises_message(
+            sa_exc.InvalidRequestError,
+            "This Session's transaction has been rolled back by a nested "
+            "rollback\(\) call.  To begin a new transaction, issue "
+            "Session.rollback\(\) first.",
+            trans.commit
+        )
+
+    def test_deactive_status_check_w_exception(self):
+        sess = create_session()
+        trans = sess.begin()
+        trans2 = sess.begin(subtransactions=True)
+        try:
+            raise Exception("test")
+        except:
+            trans2.rollback(_capture_exception=True)
+        assert_raises_message(
+            sa_exc.InvalidRequestError,
+            "This Session's transaction has been rolled back due to a "
+            "previous exception during flush. To begin a new transaction "
+            "with this Session, first issue Session.rollback\(\). "
+            "Original exception was: test",
+            trans.commit
+        )
+
     def _inactive_flushed_session_fixture(self):
         users, User = self.tables.users, self.classes.User
 
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.