Source

strainer / strainer / operators.py

from xhtmlify import xhtmlify, XMLParsingError, ValidationError
from xml.etree import ElementTree as etree
from xml.parsers.expat import ExpatError
import copy, re
from pprint import pformat, pprint
try:
    from simplejson import loads
except ImportError:
    from json import loads

from nose.tools import *
from almostequal import approx_equal
import strainer.log as log

log = log.log

def remove_whitespace_nodes(node):
    new_node = copy.copy(node)
    new_node._children = []
    if new_node.text and new_node.text.strip() == '':
        new_node.text = ''
    if new_node.tail and new_node.tail.strip() == '':
        new_node.tail = ''
    for child in node.getchildren():
        if child is not None:
            child = remove_whitespace_nodes(child)
        new_node.append(child)
    return new_node

def remove_namespace(doc):
    """Remove namespace in the passed document in place."""
    for elem in doc.getiterator():
        match = re.match('(\{.*\})(.*)', elem.tag)
        if match:
            elem.tag = match.group(2)

def replace_escape_chars(needle):
    needle = needle.replace(' ', ' ')
    needle = needle.replace(u'\xa0', ' ')
    return needle

def normalize_to_xhtml(needle):
    # We still need this, when in a webtest response,   gets replaced with \xa0, and xhtmlify can't handle non-acii
    needle = replace_escape_chars(needle)
    #first, we need to make sure the needle is valid html
    needle = xhtmlify(needle)
    try:
        needle_node = etree.fromstring(needle)
    except ExpatError, e:
        raise XMLParsingError('Could not parse %s into xml. %s'%(needle, e.args[0]))
    needle_node = remove_whitespace_nodes(needle_node)
    remove_namespace(needle_node)
    needle_s = etree.tostring(needle_node)
    return needle_s

def in_xhtml(needle, haystack):
    try:
        needle_s = normalize_to_xhtml(needle)
    except ValidationError, e:
        raise XMLParsingError('Could not parse needle: %s into xml. %s'%(needle, e.message))
    try:
        haystack_s = normalize_to_xhtml(haystack)
    except ValidationError, e:
        raise XMLParsingError('Could not parse haystack: %s into xml. %s'%(haystack, e.message))
    return needle_s in haystack_s

def eq_xhtml(needle, haystack, wrap=False):
    if wrap:
        needle = '<div id="wrapper">%s</div>'
        haystack = '<div id="wrapper">%s</div>'
    try:
        needle_s = normalize_to_xhtml(needle)
    except ValidationError, e:
        raise XMLParsingError('Could not parse needle: %s into xml. %s'%(needle, e.message))
    try:
        haystack_s = normalize_to_xhtml(haystack)
    except ValidationError, e:
        raise XMLParsingError('Could not parse haystack: %s into xml. %s'%(haystack, e.message))
    return needle_s == haystack_s

def assert_in_xhtml(needle, haystack):
    """
    assert that one xhtml stream can be found within another
    """
    assert in_xhtml(needle, haystack), "%s not found in %s"%(needle, haystack)

def assert_eq_xhtml(needle, haystack, wrap=False):
    """
    assert that one xhtml stream equals another
    """
    assert eq_xhtml(needle, haystack, wrap), "%s \n --- does not equal ---\n%s"%(needle, haystack)

def assert_raises(exc, method, *args, **kw):
    try:
        method(*args, **kw)
    except exc, e:
        return e
    else:
        raise AssertionError('%s() did not raise %s' % (method.__name__, exc.__name__))

def num_eq(one, two):
    assert type(one)==type(two), 'The types %s and %s do not match' % (type(one), type(two))
    eq_(one, two, 'The values %s and %s do not equal' % (one, two))

def neq_(one, two, msg = None):
    """Shorthand for 'assert a != b, "%r == %r" % (a, b)
    """
    assert a != b, msg or "%r == %r" % (a, b)

def eq_pprint(a, b, msg=None):
    if a != b:
        log.error(msg)
        return False
    return True

def _eq_list(ca, cb, ignore=None):
    r = eq_pprint(len(ca), len(cb), "The lengths of the lists are different %s != %s" % (str(ca), str(cb)))
    if not r:
        return False
    for i, v in enumerate(ca):
        if isinstance(v, dict):
            if not _eq_dict(ca[i], cb[i], ignore=ignore):
                return False
        elif isinstance(v, list):
            if not _eq_list(ca[i], cb[i], ignore=ignore):
                return False
        else:
            if not eq_pprint(ca[i], cb[i]):
                return False
    return True

def _eq_dict(ca, cb, ignore=None):
    # assume ca and cb can be destructively modified
    if ignore:
        for key in ignore:
            if key in ca:
                del ca[key]
            if key in cb:
                del cb[key]

    #this needs to be recursive so we can '&ignore'-out ids anywhere in a json stream
    for key in set(ca.keys() + cb.keys()):
        if key not in ca:
            log.error('%s!= %s\n key "%s" not in first argument' %(ca, cb, key))
            return False
        if key not in cb:
            log.error('%s!= %s\n key "%s" not in second argument' %(ca, cb, key))
            return False
        
        v1 = ca[key]
        v2 = cb[key]
        log.info('Comparing values for key: %s', key)
        if v1 == '&ignore' or v2 == '&ignore':
            log.info('Ignored comparison for key: %s', key)
            continue
        if not isinstance(v2, basestring) and isinstance(v1, basestring):
            if not eq_pprint(type(v1), type(v2)):
                log.error('The types of values for "%s" do not match (%s vs. %s)' %(key, v1, v2))
                return False
        if isinstance(v1, list):
            if not _eq_list(v1, v2, ignore=ignore):
                return False
        elif isinstance(v1, dict):
            if not _eq_dict(v1, v2, ignore=ignore):
                return False
        elif isinstance(v1, float) and isinstance(v2, float):
            if not approx_equal(v1, v2):
                log.error('The values for "%s" do not match (%.30f vs. %.30f)' %(key, v1, v2))
                return False
        else:
            if not v1 == v2:
                log.error('The values for "%s" do not match (%s vs. %s)' %(key, v1, v2))
                return False
    return True

def eq_dict(a, b, ignore=None):
    #Make a copy as our search for ignored values is destructive
    ca = copy.deepcopy(a)
    cb = copy.deepcopy(b)
                
    return _eq_dict(ca, cb, ignore=ignore)

def eq_json(a, b):
    if isinstance(a, basestring):
        a = loads(a)
    if isinstance(b, basestring):
        b = loads(b)
        
    return eq_dict(a, b)


__all__ = [_key for _key in locals().keys() if not _key.startswith('_')]