# Source
#
# nose-xml-skippy / nose_xml.py
#
# Full commit
import os
import traceback

from nose import plugins
from nose.exc import SkipTest
from time import time

def xmlsafe(s, encoding="utf-8"):
    """Return ``str(s)`` escaped for use in XML text and attribute values.

    On Python 2, ``unicode`` input is first encoded with *encoding* so the
    result is a byte string that can be written straight to the report file.
    On Python 3 (where ``unicode`` does not exist) text is used as-is.

    The five XML special characters are replaced by their predefined
    entities. ``&`` is replaced first so that the ampersands introduced by
    the later replacements are not escaped a second time.

    :param s: any object; non-strings are converted with ``str()``.
    :param encoding: codec used for Python 2 ``unicode`` input.
    :returns: the escaped string.
    """
    try:
        text_type = unicode  # noqa: F821 -- Python 2 only
    except NameError:
        text_type = None
    if text_type is not None and isinstance(s, text_type):
        s = s.encode(encoding)
    s = str(s)
    for src, rep in (('&', '&amp;'),   # must come first
                     ('<', '&lt;'),
                     ('>', '&gt;'),
                     ('"', '&quot;'),
                     ("'", '&apos;'),
                     ):
        s = s.replace(src, rep)
    return s

class Xunit(plugins.Plugin):
    """nose plugin that writes test results to an xUnit-style XML file.

    Passing, failing and erroring tests are recorded as ``<testcase>``
    elements. Tests whose error is a ``SkipTest`` are only counted in the
    ``skip`` attribute of the enclosing ``<testsuite>`` element. The file is
    written when ``report()`` is called.
    """
    name = 'xunit'
    score = 2000
    encoding = 'UTF-8'

    def _xmlsafe(self, s):
        """Escape *s* for XML using this plugin's output encoding."""
        return xmlsafe(s, encoding=self.encoding)

    def options(self, parser, env=os.environ):
        """Register ``--xunit-file`` (default taken from ``NOSE_XUNIT_FILE``)."""
        plugins.Plugin.options(self, parser, env)
        parser.add_option('--xunit-file', action='store',
                          dest='xunit_file',
                          default=env.get('NOSE_XUNIT_FILE', 'nosetests.xml'),
                          help=("Path to xml file to store the xunit report in. "
                                "[NOSE_XUNIT_FILE]"))

    def configure(self, options, config):
        """Reset the counters and remember the report path when enabled."""
        plugins.Plugin.configure(self, options, config)
        if self.enabled:
            self.stats = {'errors': 0,
                          'failures': 0,
                          'passes': 0,
                          'skipped': 0,
                          }
            self.errorlist = []
            # Only remember the path here; the file is opened (and reliably
            # closed) in report(), so no handle is left open if the run
            # never reaches report() or a write fails.
            self.error_report_file_name = options.xunit_file

    def report(self, stream):
        """Write all collected test cases as a single ``<testsuite>`` document.

        :param stream: nose's output stream (unused; the XML goes to the
            file given by ``--xunit-file``).
        """
        self.stats['encoding'] = self.encoding
        self.stats['total'] = (self.stats['errors'] + self.stats['failures']
                               + self.stats['passes'] + self.stats['skipped'])
        with open(self.error_report_file_name, 'w') as report_file:
            report_file.write('<?xml version="1.0" encoding="%(encoding)s"?>'
                              '<testsuite name="nosetests" tests="%(total)d" '
                              'errors="%(errors)d" failures="%(failures)d" '
                              'skip="%(skipped)d">' % self.stats)
            report_file.write(''.join(self.errorlist))
            report_file.write('</testsuite>')

    def startTest(self, test):
        """Start timing *test*."""
        self._timer = time()

    def _time_taken(self):
        """Return seconds elapsed since startTest() and clear the timer.

        Returns 0.0 when no timer is running, for instance when an
        add*() hook is called without a preceding startTest(). Clearing the
        timer keeps a stale start time from being reused for such a call.
        """
        start = getattr(self, '_timer', None)
        self._timer = None
        if start is None:
            return 0.0
        return time() - start

    @staticmethod
    def _exc_type_name(exc_type):
        """Return ``module.ClassName`` for an exception class.

        Falls back to ``str(exc_type)`` for objects without ``__name__``.
        """
        name = getattr(exc_type, '__name__', None)
        if name is None:
            return str(exc_type)
        module = getattr(exc_type, '__module__', None)
        if module:
            return '%s.%s' % (module, name)
        return name

    def _add_testcase(self, test, taken, inner=None):
        """Append one ``<testcase>`` element for *test* to the report.

        :param taken: duration in seconds (written with millisecond precision).
        :param inner: already-escaped child XML (e.g. an ``<error>`` element),
            or None for a self-closing element.
        """
        test_id = test.id()
        # classname is the id without its last dotted component.
        attrs = 'classname="%s" name="%s" time="%.3f"' % (
            self._xmlsafe(test_id.rpartition('.')[0]),
            self._xmlsafe(test_id),
            taken,
        )
        if inner is None:
            self.errorlist.append('<testcase %s />' % attrs)
        else:
            self.errorlist.append('<testcase %s>%s</testcase>' % (attrs, inner))

    def _add_problem(self, tag, test, err):
        """Record *test* with an ``<error>``/``<failure>`` child built from *err*.

        :param tag: element name, ``'error'`` or ``'failure'``.
        :param err: ``(type, value, traceback)`` tuple as passed by nose.
        """
        taken = self._time_taken()
        tb = ''.join(traceback.format_exception(*err))
        inner = '<%s type="%s">%s</%s>' % (
            tag,
            self._xmlsafe(self._exc_type_name(err[0])),
            self._xmlsafe(tb),
            tag,
        )
        self._add_testcase(test, taken, inner)

    def addError(self, test, err, capt=None):
        """Record an error; ``SkipTest`` errors are only counted as skips."""
        if issubclass(err[0], SkipTest):
            self.stats['skipped'] += 1
            return
        self.stats['errors'] += 1
        self._add_problem('error', test, err)

    def addFailure(self, test, err, capt=None, tb_info=None):
        """Record a failed test."""
        self.stats['failures'] += 1
        self._add_problem('failure', test, err)

    def addSuccess(self, test, capt=None):
        """Record a passing test as a self-closing ``<testcase>``."""
        taken = self._time_taken()
        self.stats['passes'] += 1
        self._add_testcase(test, taken)