Commits

Bastian Blank  committed 1481eb8

Tree - Make sure that only unicode is written if no encoding is requested

  • Participants
  • Parent commits f7db63b

Comments (0)

Files changed (1)

File emeraldtree/tree.py

         self.encoding = encoding
         self.namespaces = namespaces
 
-    def _encode(self, text):
-        if self.encoding:
-            return text.encode(self.encoding, "xmlcharrefreplace")
-        return text
-
     def _escape_cdata(self, text):
         # escape character data
         # it's worth avoiding do-nothing calls for strings that are
             text = text.replace("<", "&lt;")
         if ">" in text:
             text = text.replace(">", "&gt;")
-        return self._encode(text)
+        return text
 
     def _escape_attrib(self, text):
         # escape attribute value
 
     def write(self, write, element):
         qnames, namespaces = self._namespaces(element)
-        self.serialize_start(write)
-        self.serialize(write, element, qnames, namespaces)
+
+        if self.encoding:
+            def write_encode(text):
+                write(text.encode(self.encoding, "xmlcharrefreplace"))
+        else:
+            write_encode = write
+
+        self.serialize_start(write_encode)
+        self.serialize(write_encode, element, qnames, namespaces)
 
 
 class TextWriter(BaseWriter):
     def serialize(self, write, elem, qnames=None, namespaces=None):
         for part in elem.itertext():
-            write(self._encode(part))
+            write(part)
 
 
 class XMLWriter(BaseWriter):
             tag = qnames[elem.tag]
 
             if tag is not None:
-                write("<" + tag)
+                write(u"<" + tag)
 
                 if elem.attrib:
                     items = elem.attrib.items()
                             v = qnames[v]
                         else:
                             v = self._escape_attrib(unicode(v))
-                        write(' ' + k + '="' + v + '"')
+                        write(u' ' + k + u'="' + v + u'"')
                 if namespaces:
                     items = namespaces.items()
                     items.sort(key=lambda x: x[1]) # sort on prefix
                     for v, k in items:
                         if k:
-                            k = ":" + k
-                        write(" xmlns%s=\"%s\"" % (
-                            self._encode(k),
+                            k = u":" + k
+                        write(u" xmlns%s=\"%s\"" % (
+                            k,
                             self._escape_attrib(v)
                             ))
                 if len(elem):
-                    write(">")
+                    write(u">")
                     for e in elem:
                         self.serialize(write, e, qnames)
-                    write("</" + tag + ">")
+                    write(u"</" + tag + u">")
                 else:
-                    write(" />")
+                    write(u" />")
 
             else:
                 for e in elem:
                     self.serialize(write, e, encoding, qnames)
 
         elif isinstance(elem, Comment):
-            write("<!--%s-->" % self._escape_cdata(elem.text))
+            write(u"<!--%s-->" % self._escape_cdata(elem.text))
 
         elif isinstance(elem, ProcessingInstruction):
             text = self._escape_cdata(elem.target)
             if elem.text is not None:
                 text += ' ' + self._escape_cdata(elem.text)
-            write("<?%s?>" % text)
+            write(u"<?%s?>" % text)
 
         else:
             write(self._escape_cdata(unicode(elem)))
 
     def serialize_start(self, write):
         if self.encoding and self.encoding not in ("utf-8", "us-ascii"):
-            write("<?xml version='1.0' encoding='%s'?>\n" % self.encoding)
+            write(u"<?xml version='1.0' encoding='%s'?>\n" % self.encoding)
 
 
 class HTMLWriter(BaseWriter):
             tag = qnames[elem.tag]
 
             if tag is not None:
-                write("<" + tag)
+                write(u"<" + tag)
 
                 if elem.attrib:
                     items = elem.attrib.items()
                         else:
                             v = self._escape_attrib(unicode(v))
                         # FIXME: handle boolean attributes
-                        write(' ' + k + '="' + v + '"')
+                        write(u' ' + k + u'="' + v + u'"')
                 if namespaces:
                     items = namespaces.items()
                     items.sort(key=lambda x: x[1]) # sort on prefix
                     for v, k in items:
                         if k:
-                            k = ":" + k
-                        write(" xmlns%s=\"%s\"" % (
-                            self._encode(k),
+                            k = u":" + k
+                        write(u" xmlns%s=\"%s\"" % (
+                            k,
                             self._escape_attrib(v)
                             ))
-                write(">")
+                write(u">")
 
                 if tag.lower() in ('script', 'style'):
-                    write(self._encode(''.join(elem.itertext())))
+                    write(u''.join(elem.itertext()))
                 else:
                     for e in elem:
                         self.serialize(write, e, qnames)
 
                 if tag not in self.empty_elements:
-                    write("</" + tag + ">")
+                    write(u"</" + tag + u">")
 
             else:
                 for e in elem:
                     self.serialize(write, e, qnames)
 
         elif isinstance(elem, Comment):
-            write("<!--%s-->" % self._escape_cdata(elem.text))
+            write(u"<!--%s-->" % self._escape_cdata(elem.text))
 
         elif isinstance(elem, ProcessingInstruction):
             text = self._escape_cdata(elem.target)
             if elem.text is not None:
                 text += ' ' + self._escape_cdata(elem.text)
-            write("<?%s?>" % text)
+            write(u"<?%s?>" % text)
 
         else:
-            write(self._escape_cdata(elem))
+            write(self._escape_cdata(unicode(elem)))