1. tlynn
  2. strainer-py3

Commits

tlynn  committed 9ac7146

Fix Python 3 support.

  • Participants
  • Parent commits ea3a786
  • Branches default

Comments (0)

Files changed (7)

File strainer/operators.py

View file
-from .xhtmlify import xhtmlify, XMLParsingError, ValidationError, PY3
+from .xhtmlify import xhtmlify, XMLParsingError, ValidationError
 from xml.etree import ElementTree as etree
 from xml.parsers.expat import ExpatError
 import copy
     except ValidationError as e:
         raise XMLParsingError(
             'Could not parse needle: %s into xml. %s' %
-            (needle, e.args[0] if PY3 else e.message))
+            (needle, e.args[0]))
     try:
         haystack_s = normalize_to_xhtml(haystack)
     except ValidationError as e:
         raise XMLParsingError(
             'Could not parse haystack: %s into xml. %s' %
-            (haystack, e.args[0] if PY3 else e.message))
+            (haystack, e.args[0]))
     return needle_s in haystack_s
 
 
     except ValidationError as e:
         raise XMLParsingError(
             'Could not parse needle: %s into xml. %s' %
-            (needle, e.args[0] if PY3 else e.message))
+            (needle, e.args[0]))
     try:
         haystack_s = normalize_to_xhtml(haystack)
     except ValidationError as e:
         raise XMLParsingError(
             'Could not parse haystack: %s into xml. %s' %
-            (haystack, e.args[0] if PY3 else e.message))
+            (haystack, e.args[0]))
     return needle_s == haystack_s
 
 

File strainer/validate.py

View file
 
 from pkg_resources import resource_string
 from strainer.doctypes import *
-from strainer.xhtmlify import PY3
 
 
 __all__ = ['validate_xhtml', 'validate_xhtml_fragment', 'XHTMLSyntaxError',
         tline = doctype.count('\n')
         message = re.sub(r'line (\d+)',
                          lambda m: 'line %s' % (int(m.group(1)) - tline),
-                         e.args[0] if PY3 else e.message)
+                         e.args[0])
         raise XHTMLSyntaxError(message)
 
 
         # relative to the fragment.
         message = re.sub(r'line (\d+)',
                          lambda m: 'line %s' % (int(m.group(1)) - tline),
-                         e.args[0] if PY3 else e.message)
+                         e.args[0])
         raise XHTMLSyntaxError(message)
 
 

File strainer/wellformed.py

View file
     import htmlentitydefs
 
 from xml.sax._exceptions import SAXParseException
-from strainer.xhtmlify import PY3
+
 
 __all__ = ['is_wellformed_xml', 'is_wellformed_xhtml']
 
                 column -= len(doctype) - (doctype.rfind('\n') + 1)
             # Convert column to 1-based indexing
             record_error('line %d, column %d: %s' % (
-                line, column + 1, e.args[0] if PY3 else e.message
+                line, column + 1, e.args[0]
             ))
         return False
 

File strainer/xhtmlify.py

View file
     encodingdecl = ''
     if encoding is not None:
         EncName_re = re.compile(r'[A-Za-z][A-Za-z0-9._-]*\Z')  # from XML spec
+        if isinstance(encoding, six.binary_type):
+            encoding = encoding.decode('ascii', 'replace')
         if isinstance(encoding, six.text_type) and EncName_re.match(encoding):
             encodingdecl = ' encoding="%s"' % encoding
         else:
             # Don't tell them expected format, guessing won't help
             raise ValidationError('Bad standalone value in XML declaration',
                                   0, 1, 1, [])
-    return '<?xml version="%s"%s%s ?>' % (version, encodingdecl, sddecl)
+    return six.u('<?xml version="%s"%s%s ?>') % (version, encodingdecl, sddecl)
 
 
 def fix_xmldecl(xml, encoding=None, add_encoding=False, default_version='1.0'):
     # This code started as a copy of sniff_encoding(), which follows the
     # XML spec.  This version uses a more lenient parser.
     EOS = r'\Z'  # end of string regexp
+    EOB = six.b(EOS)  # end of bytes regexp
+    bom = chr(0xfeff) if PY3 else unichr(0xfeff)
     starts_utf16_re = re.compile('utf[_-]?16', re.IGNORECASE)
     bomless_utf16_re = re.compile('utf[_-]?16[_-]?[bl]e\Z', re.IGNORECASE)
     unicode_input = isinstance(xml, six.text_type)
             if not unicode_input and not (
                 xml.startswith(codecs.BOM_UTF16_LE) or
                 xml.startswith(codecs.BOM_UTF16_BE)):
-
-                xml = six.u('\ufeff').encode(encoding) + xml
+                xml = bom.encode(encoding) + xml
             elif unicode_input and bomless_utf16_re.match(encoding):
-                xml = six.u('\ufeff') + xml
+                xml = bom + xml
             # "else: pass"; Python adds the BOM when encoding unicode as UTF-16
     if unicode_input:
         if encoding:
-            xmlstr = xml.encode(encoding)
+            xmlbytes = xml.encode(encoding)
         else:
-            xmlstr = xml.encode('UTF-8')
+            xmlbytes = xml.encode('UTF-8')
     else:
-        xmlstr = xml
+        xmlbytes = xml
     if encoding:
         enc = encoding
     else:
-        enc = sniff_bom_encoding(xmlstr)
+        enc = sniff_bom_encoding(xmlbytes)
     if unicode_input:
-        xml = xmlstr
+        xml = xmlbytes
+        def decode(s):
+            result = s.decode(enc, 'strict')
+            if result.startswith(bom):
+                result = result[1:]
+            return result
+    else:
+        decode = lambda s: s
     # We must use an encoder to handle utf_8_sig properly.
     encode = codecs.lookup(enc).incrementalencoder().encode
     if bomless_utf16_re.match(enc):
         # These need a BOM prefix according to the spec but the default
         # Python encodings of that name don't provide one.
-        prefix = encode(six.u('\ufeff'))
+        prefix = encode(bom)
     else:
-        prefix = encode('')
+        prefix = encode(six.u(''))
     chars_we_need = ('''abcdefghijklmnopqrstuvwxyz'''
                      '''ABCDEFGHIJKLMNOPQRSTUVWXYZ'''
                      '''0123456789.-_ \t\r\n<?'"[]:()+*>''')
     assert encode(chars_we_need * 3) == encode(chars_we_need) * 3, enc
-    L = lambda s: re.escape(str(s) if PY3 else encode(s))  # encoded form of literal s
-    group = lambda s: '(%s)' % s
-    # optional = lambda s: '(?:%s)?' % s
-    oneof = lambda opts: '(?:%s)' % '|'.join([str(opt) for opt in opts])
+    L = lambda s: re.escape(encode(s))  # encoded form of literal s
+    group = lambda s: six.b('(') + s + six.b(')')
+    oneof = lambda opts: six.b('(?:') + six.b('|').join(opts) + six.b(')')
     charset = lambda s: oneof([L(c) for c in s])
-    all_until = lambda s: '(?:(?!%s).)*' % s
-    # caseless = lambda s: oneof([L(c.lower()) for c in s] +
-    #                            [L(c.upper()) for c in s])
+    all_until = lambda s: six.b('(?:(?!') + s + six.b(').)*')
     upper = charset('ABCDEFGHIJKLMNOPQRSTUVWXYZ')
     lower = charset('abcdefghijklmnopqrstuvwxyz')
     digits = charset('0123456789')
     punc = charset('._-')
-    Name = '(?:%s%s*)' % (oneof([upper, lower]),
-                          oneof([upper, lower, digits, punc]))
-    Ss = charset(' \t\r\n\f') + '*'  # optional white space (inc. formfeed)
-    Sp = charset(' \t\r\n\f') + '+'  # required white space (inc. formfeed)
+    Name = six.b('(?:') + (oneof([upper, lower]) +
+                           oneof([upper, lower, digits, punc])) + six.b('*)')
+    Ss = charset(' \t\r\n\f') + six.b('*') # optional white space (inc. \f)
+    Sp = charset(' \t\r\n\f') + six.b('+') # required white space (inc. \f)
     VERSION = encode('version')
     ENCODING = encode('encoding')
     STANDALONE = encode('standalone')
-    StartDecl = ''.join([str(prefix), Ss, L('<'), Ss, L('?'), Ss,
-                         oneof([L('xml'), L('xmL'), L('xMl'), L('xML'),
-                                L('Xml'), L('XmL'), L('XMl'), L('XML')])])
-    Attr = ''.join([group(Sp), group(Name), group(''.join([Ss, L('='), Ss])),
+    joinbytes = six.b('').join
+    StartDecl = joinbytes([prefix, Ss, L('<'), Ss, L('?'), Ss,
+                           oneof([L('xml'), L('xmL'), L('xMl'), L('xML'),
+                                  L('Xml'), L('XmL'), L('XMl'), L('XML')])])
+    Attr = joinbytes([
+        group(Sp), group(Name), group(joinbytes([Ss, L('='), Ss])),
         oneof([
             group(L('"') + all_until(oneof([L('"'), L('<'), L('>')])) + L('"')),
             group(L("'") + all_until(oneof([L("'"), L('<'), L('>')])) + L("'")),
         ])
     ])
     Attr_re = re.compile(Attr, re.DOTALL)
-    EndDecl = ''.join([
-        group(Ss), oneof([''.join([L('?'), Ss, L('>')]), L('>')])
+    EndDecl = joinbytes([
+        group(Ss), oneof([joinbytes([L('?'), Ss, L('>')]), L('>')])
     ])
-    m = re.match(StartDecl, str(xml) if PY3 else xml)
+    m = re.match(StartDecl, xml)
     if m:
         pos = m.end()
         attrs = {}
                 if name in attrs:
                     pass  # TODO: warn: already got a value for xxx
                 elif name == VERSION:
-                    m3 = re.match(Ss + group(L("1.") + digits) + Ss + EOS,
+                    m3 = re.match(Ss + group(L("1.") + digits) + Ss + EOB,
                                   value)
                     if m3:
                         attrs[name] = wspace + name + eq + \
                     else:
                         pass  # TODO: warn: expected 1.x
                 elif name == ENCODING:
-                    m3 = re.match(Ss + group(Name) + Ss + EOS, value)
+                    m3 = re.match(Ss + group(Name) + Ss + EOB, value)
                     if m3:
                         attrs[name] = wspace + name + eq + \
                                 quotes + m3.group(1) + quotes
                                     L('Yes'), L('YeS'), L('YEs'), L('YES')])),
                             group(oneof([L('no'), L('nO'),
                                          L('No'), L('NO')]))
-                        ]) + Ss + EOS,
+                        ]) + Ss + EOB,
                         value)
                     if m3:
                         yes, no = m3.groups()
             attrs[ENCODING] = encode(" encoding='%s'" % enc)
         m4 = re.compile(EndDecl).match(xml, pos)
         if m4:
-            return (
+            return decode(
                 prefix + encode('<?xml') +
                 attrs.get(VERSION, encode(" version='%s'" % default_version)) +
-                (attrs.get(ENCODING) if ENCODING in attrs else '') +
-                (attrs.get(STANDALONE) if STANDALONE in attrs else '') +
+                (attrs.get(ENCODING) if ENCODING in attrs else six.b('')) +
+                (attrs.get(STANDALONE) if STANDALONE in attrs else six.b('')) +
                 m4.group(1).replace(encode('\f'), encode(' ')) +
                 encode('?>') + xml[m4.end():])
         else:
                     endpos = m5.end()
                 else:
                     endpos = m5.start()
-                return xml[:m.start()] + xml[endpos:]  # remove bad decl
+                return decode(xml[:m.start()] + xml[endpos:]) # remove bad decl
             else:
-                return ''  # unterminated, drop entire document (inc. BOM)
+                raise ValidationError("Unterminated XML declaration")
     if unicode_input:
-        xml = xml.decode(enc, 'strict')  # reverse the encoding done earlier
+        xml = decode(xml)  # reverse the encoding done earlier
     return xml  # no decl detected
 
 
         six.u('\u3001-\uD7FF\uF900-\uFDCF\uFDF0-\uFFFD]'))
     if len(six.u('\U00010000')) == 1:
         NameStartChar = NameStartChar[:-1] + six.u('\U00010000-\U000EFFFF]')
-    NameChar = NameStartChar[:-1] + six.u("0-9\xB7\u0300-\u036F\u203F-\u2040\-]")
+    NameChar = (NameStartChar[:-1] +
+                six.u("0-9\xB7\u0300-\u036F\u203F-\u2040\-]"))
     Name = NameStartChar + any(NameChar)
     Nmtoken = some(NameChar)
     quoted = oneof('"[^<>"]*"', "'[^<>']*'")
     if not encoding:
         encoding = sniff_encoding(html)
     unicode_input = isinstance(html, six.text_type)
-    if unicode_input and not PY3:
+    if unicode_input:
         html = html.encode(encoding, 'strict')
-    if not isinstance(html, str):
-        raise TypeError("Expected string, got %s" % type(html))
-    if not PY3:
-        html = html.decode(encoding, 'replace')
+    if not isinstance(html, six.binary_type):
+        raise TypeError("Expected %s, got %s" %
+                        (six.binary_type.__name__, type(html)))
+    html = html.decode(encoding, 'replace')
     # "in HTML, the Formfeed character (U+000C) is treated as white space"
     html = html.replace(six.u('\u000C'), six.u(' '))
     # Replace disallowed characters with U+FFFD (unicode replacement char)
         import sys
         if len(sys.argv) == 2:
             if sys.argv[1] == '-':
-                html = sys.stdin.read()
+                html = sys.stdin.buffer.read() if PY3 else sys.stdin.read()
             else:
-                html = open(sys.argv[1]).read()
+                html = open(sys.argv[1], 'rb').read()
         else:
             sys.exit('usage: %s HTMLFILE' % sys.argv[0])
     xhtml = xhtmlify(html)
     except ValidationError:
         print(xhtml)
         raise
-    xmlparse(re.sub('(?s)<!(?!\[).*?>', '', xhtml))  # ET can't handle <!...>
+    assert isinstance(xhtml, six.binary_type)
+    xhtml = xhtml.decode(sniff_encoding(xhtml))
+    xmlparse(re.sub(six.u('(?s)<!(?!\[).*?>'), '', xhtml))  # ET can't handle <!...>
     if len(sys.argv) == 2:
         sys.stdout.write(xhtml)
     return xhtml
        setting wrap to True or forced by setting wrap to False."""
     import xml.parsers.expat
     from xml.etree import ElementTree as ET
+    unicode_input = isinstance(snippet, six.text_type)
+    if not encoding:
+        encoding = sniff_encoding(snippet)
+    if unicode_input:
+        bom = chr(0xfeff) if six.PY3 else unichr(0xfeff)
+        bomless_snippet = snippet
+        if snippet.startswith(bom):
+            bomless_snippet = snippet[1:]
+        snippet_bytes = bomless_snippet.encode(encoding)
+        snippet_text = snippet
+    else:
+        snippet_bytes = snippet
+        snippet_text = snippet.decode(encoding, 'strict')
     if wrap is None:
-        wrap = (not snippet.startswith('<?xml') and
-                not snippet.startswith(six.u('\ufeff<?xml')))
+        wrap = not snippet_text.startswith((six.u('<?xml'),
+                                            six.u('\ufeff<?xml')))
+    if wrap:
+        input_text = six.u('<document>\n%s\n</document>') % snippet_text
+        input_bytes = input_text.encode(encoding)
+    else:
+        input_bytes = snippet_bytes
+        input_text = snippet_bytes.decode(encoding)
     try:
         if encoding:
             try:
                 parser = ET.XMLParser()  # old version
         else:
             parser = ET.XMLParser()  # let it use the standard algorithm
-        if wrap:  # XXX: not safe for non-ascii-ish encoded strs
-            input = '<document>\n%s\n</document>' % snippet
-        else:
-            input = snippet
-        if isinstance(snippet, six.text_type):
-            if not encoding:
-                encoding = sniff_encoding(snippet)
-            input = input.encode(encoding)
-        parser.feed(input)
+        parser.feed(input_bytes)
         parser.close()
     except xml.parsers.expat.ExpatError as e:
         lineno, offset = e.lineno, e.offset
         lineno -= 1
-        if lineno == input.count('\n'):  # last line => </document>
+        if lineno == input_text.count('\n'):  # last line => </document>
             lineno -= 1
             offset = len(snippet) - snippet.rfind('\n')
+        lineno = max(lineno, 1)
         message = re.sub(r'line \d+', 'line %d' % lineno,
-                         e.args[0] if PY3 else e.message, count=1)
+                         e.args[0], count=1)
         message = re.sub(r'column \d+', 'column %d' % offset,
                          message, count=1)
         parse_error = xml.parsers.expat.ExpatError(message)
 def sniff_encoding(xml):
     """Detects the XML encoding as per XML 1.0 section F.1."""
     if isinstance(xml, six.binary_type):
-        xmlstr = xml
+        xmlbytes = xml
     elif isinstance(xml, six.text_type):
-        xmlstr = xml.encode('utf-8')
+        xmlbytes = xml.encode('utf-8')
     else:
         raise TypeError('Expected a string, got %r' % type(xml))
-    enc = sniff_bom_encoding(xmlstr)
+    enc = sniff_bom_encoding(xmlbytes)
     # Now the fun really starts. We compile the encoded sniffer regexp.
     # We must use an encoder to handle utf_8_sig properly.
     encode = codecs.lookup(enc).incrementalencoder().encode
     prefix = encode('')  # any header such as a UTF-8 BOM
     if enc in ('utf_16_le', 'utf_16_be'):
         prefix = six.u('\ufeff').encode(enc)  # the standard approach fails
-    L = lambda s: re.escape(str(s) if PY3 else encode(s))  # encoded form of literal s
-    optional = lambda s: '(?:%s)?' % s
-    oneof = lambda opts: '(?:%s)' % '|'.join(
-            [str(opt) if PY3 else opt for opt in opts])
+    L = lambda s: re.escape(encode(s))  # encoded form of literal s
+    optional = lambda s: six.b('(?:') + s + six.b(')?')
+    oneof = lambda opts: six.b('(?:') + six.b('|').join(opts) + six.b(')')
     charset = lambda s: oneof([L(c) for c in s])
     upper = charset('ABCDEFGHIJKLMNOPQRSTUVWXYZ')
     lower = charset('abcdefghijklmnopqrstuvwxyz')
     digit = charset('0123456789')
-    digits = digit + '+'
+    digits = digit + six.b('+')
     punc = charset('._-')
-    name = '(?:%s%s*)' % (oneof([upper, lower]),
-                          oneof([upper, lower, digit, punc]))
-    Ss = charset(' \t\r\n') + '*'  # optional white space
-    Sp = charset(' \t\r\n') + '+'  # required white space
-    Eq = ''.join([Ss, L('='), Ss])
-    VersionInfo = ''.join([
+    name = six.b('(?:') + (oneof([upper, lower]) +
+                           oneof([upper, lower, digit, punc])) + six.b('*)')
+    Ss = charset(' \t\r\n') + six.b('*')  # optional white space
+    Sp = charset(' \t\r\n') + six.b('+')  # required white space
+    joinbytes = six.b('').join
+    Eq = joinbytes([Ss, L('='), Ss])
+    VersionInfo = joinbytes([
         Sp,
         L('version'),
         Eq,
             L('"1.') + digits + L('"'),
         ])
     ])
-    EncodingDecl = ''.join([
+    EncodingDecl = joinbytes([
         Sp,
         L('encoding'),
         Eq,
         oneof([
-            L("'") + '(?P<enc_dq>%s)' % name + L("'"),
-            L('"') + '(?P<enc_sq>%s)' % name + L('"')
+            L("'") + six.b('(?P<enc_dq>') + name + six.b(')') + L("'"),
+            L('"') + six.b('(?P<enc_sq>') + name + six.b(')') + L('"')
         ])
     ])
     # standalone="yes" is valid XML but almost certainly a lie...
-    SDDecl = ''.join([
+    SDDecl = joinbytes([
         Sp,
         L('standalone'),
         Eq,
             L('"') + oneof([L('yes'), L('no')]) + L('"'),
         ])
     ])
-    R = ''.join([
-        str(prefix) if PY3 else prefix,
+    R = joinbytes([
+        prefix,
         L('<?xml'),
         VersionInfo,
         optional(EncodingDecl),
         Ss,
         L('?>')
     ])
-    m = re.match(R, str(xml) if PY3 else xml)
+    m = re.match(R, xmlbytes)
     if m:
         encvalue = m.group('enc_dq')
         if encvalue is None:
             encvalue = m.group('enc_sq')
             if encvalue is None:
                 return enc
-        decl_enc = encvalue.decode(enc).encode('ascii')
+        decl_enc = encvalue.decode(enc).encode('ascii').decode('ascii')
         bom_codec = None
 
         def get_codec(encoding):
        identified an encoding, so we don't need to parse the <?xml...?>
        to extract the encoding in theory."""
     if not isinstance(xml, six.binary_type):
-        raise TypeError('Expected str/bytes, got %r' % type(xml))
+        raise TypeError('Expected %s, got %r' % (six.binary_type.__name__,
+                                                 type(xml)))
     # Warning: The UTF-32 codecs aren't present before Python 2.6...
     # See also http://bugs.python.org/issue1399
     enc = {
-        '\x00\x00\xFE\xFF': 'utf_32',  # UCS4 1234, utf_32_be with BOM
-        '\xFF\xFE\x00\x00': 'utf_32',  # UCS4 4321, utf_32_le with BOM
-        '\x00\x00\xFF\xFE': 'undefined',  # UCS4 2143 (rare, we give up)
-        '\xFE\xFF\x00\x00': 'undefined',  # UCS4 3412 (rare, we give up)
-        '\x00\x00\x00\x3C': 'UTF_32_BE',  # UCS4 1234 (no BOM)
-        '\x3C\x00\x00\x00': 'UTF_32_LE',  # UCS4 4321 (no BOM)
-        '\x00\x00\x3C\x00': 'undefined',  # UCS4 2143 (no BOM, we give up)
-        '\x00\x3C\x00\x00': 'undefined',  # UCS4 3412 (no BOM, we give up)
-        '\x00\x3C\x00\x3F': 'UTF_16_BE',  # missing BOM
-        '\x3C\x00\x3F\x00': 'UTF_16_LE',  # missing BOM
-        '\x3C\x3F\x78\x6D': 'ASCII',
-        '\x4C\x6F\xA7\x94': 'CP037',  # EBCDIC (unknown code page)
+        six.b('\x00\x00\xFE\xFF'): 'utf_32',  # UCS4 1234, utf_32_be with BOM
+        six.b('\xFF\xFE\x00\x00'): 'utf_32',  # UCS4 4321, utf_32_le with BOM
+        six.b('\x00\x00\xFF\xFE'): 'undefined', # UCS4 2143 (rare, give up)
+        six.b('\xFE\xFF\x00\x00'): 'undefined', # UCS4 3412 (rare, give up)
+        six.b('\x00\x00\x00\x3C'): 'UTF_32_BE', # UCS4 1234 (no BOM)
+        six.b('\x3C\x00\x00\x00'): 'UTF_32_LE', # UCS4 4321 (no BOM)
+        six.b('\x00\x00\x3C\x00'): 'undefined', # UCS4 2143 (no BOM, give up)
+        six.b('\x00\x3C\x00\x00'): 'undefined', # UCS4 3412 (no BOM, give up)
+        six.b('\x00\x3C\x00\x3F'): 'UTF_16_BE', # missing BOM
+        six.b('\x3C\x00\x3F\x00'): 'UTF_16_LE', # missing BOM
+        six.b('\x3C\x3F\x78\x6D'): 'ASCII',
+        six.b('\x4C\x6F\xA7\x94'): 'CP037',     # EBCDIC (unknown code page)
     }.get(xml[:4])
     if enc and enc == enc.lower():
         return enc
     if not enc:
-        if xml[:3] == '\xEF\xBB\xBF':
+        if xml[:3] == six.b('\xEF\xBB\xBF'):
             return 'utf_8_sig'  # UTF-8 with these three bytes prefixed
-        elif xml[:2] == '\xFF\xFE':
+        elif xml[:2] == six.b('\xFF\xFE'):
             return 'utf_16_le'
-        elif xml[:2] == '\xFE\xFF':
+        elif xml[:2] == six.b('\xFE\xFF'):
             return 'utf_16_be'
         else:
             enc = 'UTF-8'  # "Other"

File tests/test_operators.py

View file
 import strainer.operators as ops
-from strainer.xhtmlify import PY3
 from nose.tools import raises
 
 
         <div></div>
         </form>"""
     e = """<form action="" class="required tableform" method="post"><div /></form>"""
-    e = e.encode('ascii') if PY3 else e
     r = ops.normalize_to_xhtml(s)
     assert r == e, r
 
         <div></div>&nbsp;
         </form>"""
     e = """<form action="" class="required tableform" method="post"><div /></form>"""
-    e = e.encode('ascii') if PY3 else e
     r = ops.normalize_to_xhtml(s)
     assert r == e, r
 
     </body>
     </html>"""
     e = """<html><body><form action="" class="required tableform" method="post"><div /></form></body></html>"""
-    e = e.encode('ascii') if PY3 else e
     r = ops.normalize_to_xhtml(s)
     assert r == e, r
 

File tests/test_validate.py

View file
     try:
         validate_xhtml('<html/>', doctype=DOCTYPE_XHTML1_STRICT)
     except XHTMLSyntaxError as e:
-        emsg = e.args[0] if PY3 else e.message
+        emsg = e.args[0]
         assert 'line 1, column 8' in emsg, emsg
         assert 'Element html content does not follow the DTD' in emsg
         assert 'expecting (head, body)' in emsg.replace(' ,', ',')
     try:
         validate_xhtml_fragment('</p>')
     except XHTMLSyntaxError as e:
-        emsg = e.args[0] if PY3 else e.message
+        emsg = e.args[0]
         assert emsg == ('Opening and ending tag mismatch: '
                         'div line 0 and p, line 1, column 5'), emsg

File tests/test_xhtmlify.py

View file
 import re
 import encodings.aliases
 import codecs
+import six
+
 from strainer.xhtmlify import xhtmlify as _xhtmlify, xmlparse, ValidationError
 from strainer.xhtmlify import sniff_encoding, fix_xmldecl
 from strainer.doctypes import DOCTYPE_XHTML1_STRICT
-import six
 
 
 def xhtmlify(html, *args, **kwargs):
        and that it is idempotent (makes no changes when fed its output)."""
     _wrap = None
     if '_wrap' in kwargs:
-        _wrap = kwargs['_wrap']
-        del kwargs['_wrap']
+        _wrap = kwargs.pop('_wrap')
+    unicode_input = isinstance(html, six.text_type)
     xhtml = _xhtmlify(html, *args, **kwargs)
+    assert isinstance(xhtml, six.text_type) == unicode_input
+    regex_type = six.u if unicode_input else six.b
+    stripped_xhtml = None
     try:
         # ET can't handle <!...>
-        stripped_xhtml = re.sub(r'(?s)<!(?!\[).*?>', '', xhtml)
+        stripped_xhtml = re.sub(regex_type(r'(?s)<!(?!\[).*?>'), '', xhtml)
         xmlparse(stripped_xhtml, wrap=_wrap)
     except Exception as e:
         assert False, (stripped_xhtml, str(e))
             ''.encode(encoding)
         except LookupError:  # not trying to handle unknown encodings yet
             continue
-        xmldecl = fix_xmldecl(six.u('  <?xml>'), encoding, add_encoding=True)
+        xmldecl = fix_xmldecl(six.u('  <?xml>').encode(encoding), encoding,
+                              add_encoding=True)
         if encoding.lower().startswith('utf'):
             if '16' in encoding:
                 if 'le' in encoding.lower():
 
 def test_xhtmlify_handles_utf8_xmldecl():
     result = xhtmlify(six.u('<?xml><html>'), 'utf-8', _wrap=False)
+    assert result==six.u('<?xml version=\'1.0\'?><html xmlns="http://www.w3.org/1999/xhtml"></html>')
+    result = xhtmlify(six.u('<?xml><html>').encode('utf-8'), 'utf-8',
+                      _wrap=False)
     assert result.decode('utf-8')==six.u('<?xml version=\'1.0\'?><html xmlns="http://www.w3.org/1999/xhtml"></html>')
 
 def test_xhtmlify_handles_utf16_xmldecl():
     result = xhtmlify(six.u('<?xml><html>'), 'utf_16_be', _wrap=False)
+    assert result==six.u('<?xml version=\'1.0\'?><html xmlns="http://www.w3.org/1999/xhtml"></html>')
+    result = xhtmlify(six.u('<?xml><html>').encode('utf_16_be'), 'utf_16_be',
+                      _wrap=False)
     assert result.decode('utf16')==six.u('<?xml version=\'1.0\'?><html xmlns="http://www.w3.org/1999/xhtml"></html>')
 
 def test_doctype():