Source

django-doctest / django_doctest_tests / src / django_doctest.py

from django.conf import settings
from django.core import management
from django.db.models import get_app, get_apps
from django.test.client import Client
from urlparse import urlparse, urljoin


class TestBase(object):
    """Assertion helpers modelled on ``unittest.TestCase`` / Django's test
    assertions, usable without a ``TestCase`` instance.

    Failed checks are reported by *printing* the failure message instead of
    raising an exception, so they show up as unexpected output rather than
    aborting the run. Because nothing is raised, a helper keeps executing
    after a failed check; helpers return early where continuing would crash.

    ``assertRedirects`` fetches the redirect target with ``self.client``,
    which this class does not create itself (``Test`` provides it).
    """

    def assertRedirects(self, response, expected_path, status_code=302, target_status_code=200,
                        base_path=None):
        """Assert that ``response`` redirected to ``expected_path`` and that the
        redirect target can be loaded.

        Only the path component of the ``Location`` header is compared. The
        target path is resolved against ``base_path`` (if given), fetched with
        ``self.client`` and must answer with ``target_status_code``. When the
        response has no ``Location`` header only the status code is checked.
        """
        self.assertEqual(response.status_code, status_code,
            "Response didn't redirect as expected: Response code was %d (expected %d), in '%s'" %
                (response.status_code, status_code, response.request['PATH_INFO']))
        if not response.has_header('Location'):
            return
        # Index 2 of the urlparse result is the path component.
        path = urlparse(response['Location'])[2]
        self.assertEqual(path, expected_path,
            "Response redirected to '%s', expected '%s'" % (path, expected_path))
        path = urljoin(base_path or "", path)
        redirect_response = self.client.get(path)
        self.assertEqual(redirect_response.status_code, target_status_code,
            "Couldn't retrieve redirection page '%s': response code was %d (expected %d)" %
                (path, redirect_response.status_code, target_status_code))

    def assertContains(self, response, text, count=1, status_code=200, url="response"):
        """Assert that a page was retrieved with the expected HTTP status code
        and that ``text`` occurs exactly ``count`` times in its content.

        ``url`` is only used to label the failure message.
        """
        self.assertEqual(response.status_code, status_code,
            "Couldn't retrieve page: Response code was %d (expected %d)" %
                (response.status_code, status_code))
        real_count = response.content.count(text)
        self.assertEqual(real_count, count,
            "Found %d instances of '%s' in %s (expected %d)" % (real_count, text, url, count))

    def assertFormError(self, response, form, field, errors):
        """Assert that the form named ``form`` in the rendering context has the
        given error(s).

        ``field`` names the form field to inspect; a falsy ``field`` checks the
        form's non-field errors instead. ``errors`` is one error string or a
        list of them.
        """
        if not response.context:
            self.fail('Response did not use any contexts to render the response')
            # fail() only prints, so stop here: iterating a None context
            # below would raise TypeError.
            return

        # If there is a single context, put it into a list to simplify processing
        if not isinstance(response.context, list):
            contexts = [response.context]
        else:
            contexts = response.context

        # If a single error string is provided, make it a list to simplify processing
        if not isinstance(errors, list):
            errors = [errors]

        # Search all contexts for the error.
        found_form = False
        for i, context in enumerate(contexts):
            if form not in context:
                continue
            found_form = True
            for err in errors:
                if field:
                    if field in context[form].errors:
                        self.failUnless(err in context[form].errors[field],
                            "The field '%s' on form '%s' in context %d does not contain the error '%s' (actual errors: %s)" %
                                (field, form, i, err, list(context[form].errors[field])))
                    elif field in context[form].fields:
                        self.fail("The field '%s' on form '%s' in context %d contains no errors" %
                            (field, form, i))
                    else:
                        self.fail("The form '%s' in context %d does not contain the field '%s'" % (form, i, field))
                else:
                    self.failUnless(err in context[form].non_field_errors(),
                        "The form '%s' in context %d does not contain the non-field error '%s' (actual errors: %s)" %
                            (form, i, err, list(context[form].non_field_errors())))
        if not found_form:
            self.fail("The form '%s' was not used to render the response" % form)

    def assertTemplateUsed(self, response, template_name):
        """Assert that the template named ``template_name`` was used in
        rendering the response.

        ``response.template`` may be a list of templates or a single one.
        """
        if isinstance(response.template, list):
            template_names = [t.name for t in response.template]
            self.failUnless(template_name in template_names,
                "Template '%s' was not one of the templates used to render the response. Templates used: %s" %
                    (template_name, template_names))
        elif response.template:
            self.assertEqual(template_name, response.template.name,
                "Template '%s' was not used to render the response. Actual template was '%s', in '%s'" %
                    (template_name, response.template.name, response.request['PATH_INFO']))
        else:
            self.fail('No templates used to render the response')

    def assertTemplateNotUsed(self, response, template_name):
        """Assert that the template named ``template_name`` was NOT used in
        rendering the response. Passes when no template was used at all.
        """
        if isinstance(response.template, list):
            self.failIf(template_name in [t.name for t in response.template],
                "Template '%s' was used unexpectedly in rendering the response" % template_name)
        elif response.template:
            self.assertNotEqual(template_name, response.template.name,
                "Template '%s' was used unexpectedly in rendering the response" % template_name)

    def fail(self, msg=None):
        """Report a failure by printing ``msg`` (does not raise)."""
        print(msg)

    def failIf(self, expr, msg=None):
        """Print ``msg`` if ``expr`` is true."""
        if expr:
            print(msg)

    def failUnless(self, expr, msg=None):
        """Print ``msg`` unless ``expr`` is true."""
        if not expr:
            print(msg)

    def failUnlessRaises(self, excClass, callableObj, *args, **kwargs):
        """Call ``callableObj(*args, **kwargs)`` and report a failure unless
        it raises ``excClass``.

        Any other exception type is not caught and propagates to the caller.
        """
        try:
            callableObj(*args, **kwargs)
        except excClass:
            return
        excName = getattr(excClass, '__name__', None) or str(excClass)
        print("%s not raised" % excName)

    def failUnlessEqual(self, first, second, msg=None):
        """Report a failure if ``first == second`` is false."""
        if not first == second:
            print(msg or '%r != %r' % (first, second))

    def failIfEqual(self, first, second, msg=None):
        """Report a failure if ``first == second`` is true."""
        if first == second:
            print(msg or '%r == %r' % (first, second))

    def failUnlessAlmostEqual(self, first, second, places=7, msg=None):
        """Report a failure unless ``second - first`` rounds to zero at
        ``places`` decimal places (default 7).

        Decimal places (from zero) are usually not the same as significant
        digits (measured from the most significant digit).
        """
        if round(second - first, places) != 0:
            print(msg or '%r != %r within %r places' % (first, second, places))

    def failIfAlmostEqual(self, first, second, places=7, msg=None):
        """Report a failure if ``second - first`` rounds to zero at ``places``
        decimal places (default 7).

        Decimal places (from zero) are usually not the same as significant
        digits (measured from the most significant digit).
        """
        if round(second - first, places) == 0:
            print(msg or '%r == %r within %r places' % (first, second, places))

    # unittest-style aliases.
    assertEqual = assertEquals = failUnlessEqual

    assertNotEqual = assertNotEquals = failIfEqual

    assertAlmostEqual = assertAlmostEquals = failUnlessAlmostEqual

    assertNotAlmostEqual = assertNotAlmostEquals = failIfAlmostEqual

    assertRaises = failUnlessRaises

    assert_ = assertTrue = failUnless

    assertFalse = failIf


def to_iter(args):
    """Wrap a single truthy value without ``__iter__`` in a 1-tuple.

    Falsy values and objects that already define ``__iter__`` are returned
    unchanged.
    """
    wrap = bool(args) and not hasattr(args, "__iter__")
    return (args,) if wrap else args

def reset(app_label=None, verbosity=0):
    """Run the ``reset`` management command (non-interactively) per app.

    app_label: one app label, an iterable of labels, or a falsy value meaning
        every app returned by ``get_apps()``; for those, the label is the
        second-to-last dotted component of the app module's name (presumably
        ``myapp`` for a ``myapp.models`` module — confirm).
    verbosity: when truthy, print a header plus the module name of each
        app that was reset.

    An IndexError raised while handling an app is ignored and the next
    label is processed.
    """
    labels = to_iter(app_label) or [
        app.__name__.split('.')[-2] for app in get_apps()
    ]

    if verbosity:
        print("Reset databases...")
    for label in labels:
        try:
            management.call_command("reset", label, interactive=False)
            if verbosity:
                print("  %s" % get_app(label).__name__)
        except IndexError:
            # Deliberately swallowed; the triggering case is not visible
            # here — NOTE(review): confirm which apps raise this.
            pass

def loaddata(fixtures, verbosity=0):
    """Load fixture(s) with the ``loaddata`` management command.

    fixtures: a single fixture name or an iterable of names; a falsy value
        makes this a no-op.
    verbosity: forwarded to the command.
    """
    if not fixtures:
        return
    # A truthy argument always yields a truthy result from to_iter.
    fixture_names = to_iter(fixtures)
    management.call_command("loaddata", *fixture_names, **{"verbosity": verbosity})

def flush(verbosity=0):
    """Run the ``flush`` management command without prompting.

    verbosity: forwarded to the command.
    """
    management.call_command("flush", interactive=False, verbosity=verbosity)


class Test(TestBase):
    """A Django test ``Client`` plus the printing assertion helpers of
    ``TestBase``.

    ``self.client`` (alias ``self.c``) is the current client; it is replaced
    on ``logout()``.
    """

    # Sentinel temporarily installed as settings.TEMPLATE_STRING_IF_INVALID by
    # assertResponses; any occurrence of it in a rendered page is reported.
    invalid_string = "TEMPLATE_STRING_IF_INVALID"

    def __init__(self, fixtures=None, auth=None, invalid_string=None, **extra):
        """Create the client and, if ``auth`` is given, log in.

        fixtures: fixture name(s) used by ``refresh_data`` when it is not
            passed its own.
        auth: keyword arguments for ``Client.login``.
        invalid_string: overrides the class-level ``invalid_string``.
        extra: keyword arguments passed to ``Client(...)``.
        """
        self.extra = extra
        self.fixtures = fixtures
        self.auth = auth
        if invalid_string is not None:
            self.invalid_string = invalid_string
        self.logined = None
        self.set_client()
        if self.auth:
            self.login()
        # set_client() already sets this; repeated for subclasses that
        # override set_client without maintaining the alias.
        self.c = self.client

    def login(self, auth=None):
        """Log in with ``auth`` (falling back to ``self.auth``) and store the
        result of ``Client.login`` in ``self.logined``."""
        self.logined = self.client.login(**(auth or self.auth))

    def logout(self):
        """Discard the session by creating a fresh client."""
        #TODO
        #http://code.djangoproject.com/changeset/5916?new_path=django%2Ftrunk%2Fdjango%2Ftest
        self.set_client()
        self.logined = None

    def set_client(self):
        """(Re)create ``self.client`` and point ``self.c`` at it.

        If present, ``self.cookies`` and ``self.ipaddr`` become the
        HTTP_COOKIE and REMOTE_ADDR values; ``self.extra`` is applied last
        and overrides them.
        """
        _extra = {}
        if hasattr(self, 'cookies'):
            _extra["HTTP_COOKIE"] = self.cookies
        if hasattr(self, 'ipaddr'):
            _extra["REMOTE_ADDR"] = self.ipaddr
        _extra.update(self.extra)
        self.client = Client(**_extra)
        # Keep the short alias in sync; otherwise after logout() ``self.c``
        # would still be the old, logged-in client.
        self.c = self.client

    def refresh_data(self, app_label=None, fixtures=None, verbosity=0):
        """Reset the given app(s) (all apps if none) and reload fixtures,
        defaulting to ``self.fixtures`` when ``fixtures`` is falsy."""
        reset(app_label, verbosity)
        if (not fixtures) and hasattr(self, 'fixtures'):
            fixtures = self.fixtures
        loaddata(fixtures, verbosity)

    # Status codes that assertResponses treats as redirects.
    redirect_status_code = (301, 302)

    def assertResponses(self, urls_dict):
        """GET every URL key of ``urls_dict`` and check it against its value.

        * ``(code, expected_path)`` with ``code`` in ``redirect_status_code``:
          checked with ``assertRedirects``.
        * ``(code,)`` or ``(code, templates)`` for any other int ``code``:
          the status code must match, ``invalid_string`` (when set) must not
          appear in the page, and each template in ``templates`` (one name or
          a sequence of names) must have been used.
        * A bare int is shorthand for ``(code,)``.

        Any other value is reported as a bad test.
        """
        # Both str and unicode (Python 2) template names count as one name.
        string_types = (str, type(u''))
        for key, value in urls_dict.items():
            if isinstance(value, int):
                value = (value,)
            if value[0] in self.redirect_status_code:
                base_path, status_code, expected_path = key, value[0], value[1]
                response = self.client.get(base_path)
                self.assertRedirects(response, expected_path, status_code=status_code,
                                     base_path=base_path)
            elif isinstance(value[0], int):
                base_path, status_code = key, value[0]
                if self.invalid_string:
                    org_invalid_string = settings.TEMPLATE_STRING_IF_INVALID
                    settings.TEMPLATE_STRING_IF_INVALID = self.invalid_string
                try:
                    response = self.client.get(base_path)
                    self.assertEqual(response.status_code, status_code,
                        "Response code was %d (expected %d), in '%s'" %
                            (response.status_code, status_code, key))
                    if self.invalid_string:
                        self.assertContains(response, self.invalid_string, count=0, url=base_path)
                finally:
                    # Restore the global setting even if the request raised.
                    if self.invalid_string:
                        settings.TEMPLATE_STRING_IF_INVALID = org_invalid_string
                try:
                    templates = value[1]
                except IndexError:
                    continue
                if not templates:
                    continue
                if isinstance(templates, string_types):
                    templates = [templates]
                for template_name in templates:
                    self.assertTemplateUsed(response, template_name)
            else:
                print("Bad test. '%s': %s" % (key, value))
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.