Mikhail Korobov avatar Mikhail Korobov committed 515a955

Fix transaction handling for django 1.5. Thanks @mbraak.

Comments (0)

Files changed (3)

 * bigkevmcd
 * Jeroen Vloothuis
 * Tai Lee
-* David Winterbottom
+* David Winterbottom
+* Marco Braak

django_webtest/__init__.py

 from django.test.client import store_rendered_templates
 from django.utils.functional import curry
 from django.utils.importlib import import_module
-from webtest import TestApp
-from webtest.compat import to_string
-
-from django_webtest.middleware import DjangoWsgiFix
-from django_webtest.response import DjangoWebtestResponse
-
+from django.core import signals
+from django.db import close_connection
 try:
     from django.core.servers.basehttp import AdminMediaHandler as StaticFilesHandler
 except ImportError:
     from django.contrib.staticfiles.handlers import StaticFilesHandler
 
+from webtest import TestApp
+from webtest.compat import to_string
+
+from django_webtest.response import DjangoWebtestResponse
+
 
 class DjangoTestApp(TestApp):
 
         super(DjangoTestApp, self).__init__(self.get_wsgi_handler(), extra_environ, relative_to)
 
     def get_wsgi_handler(self):
-        return DjangoWsgiFix(StaticFilesHandler(WSGIHandler()))
+        return StaticFilesHandler(WSGIHandler())
 
     def _update_environ(self, environ, user):
         if user:
         return environ
 
     def do_request(self, req, status, expect_errors):
-        req.environ.setdefault('REMOTE_ADDR', '127.0.0.1')
 
-        # is this a workaround for https://code.djangoproject.com/ticket/11111 ?
-        req.environ['REMOTE_ADDR'] = to_string(req.environ['REMOTE_ADDR'])
-        req.environ['PATH_INFO'] = to_string(req.environ['PATH_INFO'])
+        # Django closes the database connection after every request;
+        # this breaks the use of transactions in your tests.
+        signals.request_finished.disconnect(close_connection)
 
-        # Curry a data dictionary into an instance of the template renderer
-        # callback function.
-        data = {}
-        on_template_render = curry(store_rendered_templates, data)
-        template_rendered.connect(on_template_render)
+        try:
+            req.environ.setdefault('REMOTE_ADDR', '127.0.0.1')
 
-        response = super(DjangoTestApp, self).do_request(req, status, expect_errors)
+            # is this a workaround for https://code.djangoproject.com/ticket/11111 ?
+            req.environ['REMOTE_ADDR'] = to_string(req.environ['REMOTE_ADDR'])
+            req.environ['PATH_INFO'] = to_string(req.environ['PATH_INFO'])
 
-        # Add any rendered template detail to the response.
-        # If there was only one template rendered (the most likely case),
-        # flatten the list to a single element.
-        def flattend(detail):
-            if len(data[detail]) == 1:
-                return data[detail][0]
-            return data[detail]
+            # Curry a data dictionary into an instance of the template renderer
+            # callback function.
+            data = {}
+            on_template_render = curry(store_rendered_templates, data)
+            template_rendered.connect(on_template_render)
 
-        response.context = None
-        response.template = None
-        response.templates = data.get('templates', None)
+            response = super(DjangoTestApp, self).do_request(req, status, expect_errors)
 
-        if data.get('context'):
-            response.context = flattend('context')
+            # Add any rendered template detail to the response.
+            # If there was only one template rendered (the most likely case),
+            # flatten the list to a single element.
+            def flattend(detail):
+                if len(data[detail]) == 1:
+                    return data[detail][0]
+                return data[detail]
 
-        if data.get('template'):
-            response.template = flattend('template')
-        elif data.get('templates'):
-            response.template = flattend('templates')
+            response.context = None
+            response.template = None
+            response.templates = data.get('templates', None)
 
-        response.__class__ = DjangoWebtestResponse
-        return response
+            if data.get('context'):
+                response.context = flattend('context')
+
+            if data.get('template'):
+                response.template = flattend('template')
+            elif data.get('templates'):
+                response.template = flattend('templates')
+
+            response.__class__ = DjangoWebtestResponse
+            return response
+        finally:
+            signals.request_finished.connect(close_connection)
+
 
     def get(self, url, params=None, headers=None, extra_environ=None,
             status=None, expect_errors=False, user=None, auto_follow=False,

django_webtest/middleware.py

 # -*- coding: utf-8 -*-
 from django.contrib.auth.middleware import RemoteUserMiddleware
 from django.core.exceptions import ImproperlyConfigured
-from django.core import signals
-from django.db import close_connection
 from django.contrib import auth
 
 class WebtestUserMiddleware(RemoteUserMiddleware):
 class DisableCSRFCheckMiddleware(object):
     def process_request(self, request):
         request._dont_enforce_csrf_checks = True
-
-
-class DjangoWsgiFix(object):
-    """Django closes the database connection after every request;
-    this breaks the use of transactions in your tests. This wraps
-    around Django's WSGI interface and will disable the critical
-    signal handler for every request served.
-
-    Note that we really do need to do this individually a every
-    request, not just once when our WSGI hook is installed, since
-    Django's own test client does the same thing; it would reinstall
-    the signal handler if used in combination with us.
-
-    From django-test-utils.
-    Note: that's WSGI middleware, not django's.
-    """
-    def __init__(self, app):
-        self.app = app
-
-    def __call__(self, environ, start_response):
-        signals.request_finished.disconnect(close_connection)
-        try:
-            return self.app(environ, start_response)
-        finally:
-            signals.request_finished.connect(close_connection)
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.