Matt Chaput avatar Matt Chaput committed 524dc1e

Fixed FST code to work with unicode strings and Py3 bytes objects.

Comments (0)

Files changed (8)

src/whoosh/codec/standard.py

         self.text = text
         self.terminfo = base.FileTermInfo()
         if self.spelling:
-            self.dawg.insert(utf8encode(text)[0])
+            self.dawg.insert_string(utf8encode(text)[0])
         self._start_blocklist()
 
     def add(self, docnum, weight, valuestring, length):
     def add_spell_word(self, fieldname, text):
         if self.dawg is None:
             self._make_dawg_files()
-        self.dawg.insert(utf8encode(text)[0])
+        self.dawg.insert_string(utf8encode(text)[0])
 
     def finish_term(self):
         if self.block is None:

src/whoosh/fields.py

 from whoosh.analysis import (IDAnalyzer, RegexAnalyzer, KeywordAnalyzer,
                              StandardAnalyzer, NgramAnalyzer, Tokenizer,
                              NgramWordAnalyzer, Analyzer)
-from whoosh.compat import (with_metaclass, itervalues, string_type, u,
+from whoosh.compat import (with_metaclass, itervalues, string_type, u, b,
                            integer_types, long_type, text_type, xrange, PY3)
 from whoosh.support.numeric import (int_to_text, text_to_int, long_to_text,
                                     text_to_long, float_to_text, text_to_float,
 # fields. There's no "out-of-band" value possible (except for floats, where we
 # use NaN), so we try to be conspicuous at least by using the maximum possible
 # value
-NaN = struct.unpack("<f", '\x00\x00\xc0\xff')[0]
+NaN = struct.unpack("<f", b('\x00\x00\xc0\xff'))[0]
 NUMERIC_DEFAULTS = {"b": 2 ** 7 - 1, "B": 2 ** 8 - 1, "h": 2 ** 15 - 1,
                     "H": 2 ** 16 - 1, "i": 2 ** 31 - 1, "I": 2 ** 32 - 1,
                     "q": 2 ** 63 - 1, "Q": 2 ** 64 - 1, "f": NaN,

src/whoosh/filedb/filewriting.py

         for fieldname in fieldnames:
             gw.start_field(fieldname)
             for word in r.lexicon(fieldname):
-                gw.insert(utf8encode(word)[0])
+                gw.insert_string(utf8encode(word)[0])
             gw.finish_field()
         gw.close()
 

src/whoosh/spelling.py

     for word in wordlist:
         if strip:
             word = word.strip()
-        gw.insert(utf8encode(word)[0])
+        gw.insert_string(word)
     gw.close()
 
 

src/whoosh/support/dawg.py

 from whoosh.compat import (b, BytesIO, xrange, iteritems, iterkeys, bytes_type,
                            izip)
 from whoosh.filedb.structfile import StructFile
-from whoosh.system import _INT_SIZE, pack_byte, pack_int, pack_uint, pack_long
+from whoosh.system import (_INT_SIZE, pack_byte, pack_ushort, pack_int,
+                           pack_uint, pack_long)
+from whoosh.util import utf8encode, utf8decode, varint
 
 
 class FileVersionError(Exception):
 
 emptybytes = b("")
 
+ARC_LAST = 1
+ARC_ACCEPT = 2
+ARC_STOP = 4
+ARC_HAS_VAL = 8
+ARC_HAS_ACCEPT_VAL = 16
+MULTIBYTE_LABEL = 32
+
 
 # FST Value types
 
 # Cursor
 
 class BaseCursor(object):
+    """Base class for cursor objects.
+    """
+
     def is_active(self):
+        """Returns True if this cursor is still active, that is it has not
+        read past the last arc in the graph.
+        """
+
         raise NotImplementedError
 
     def label(self):
+        """Returns the label bytes of the current arc.
+        """
+
         raise NotImplementedError
 
     def prefix(self):
+        """Returns a sequence of the label bytes for the path from the root
+        to the current arc.
+        """
+
         raise NotImplementedError
 
     def prefix_bytes(self):
+        """Returns the label bytes for the path from the root to the current
+        arc as a single joined bytes object.
+        """
+
         return emptybytes.join(self.prefix())
 
     def peek_key(self):
+        """Returns a sequence of label bytes representing the next closest
+        key in the graph.
+        """
+
         for label in self.prefix():
             yield label
         c = self.copy()
             yield c.label()
 
     def peek_key_bytes(self):
+        """Returns the next closest key in the graph as a single bytes object.
+        """
+
         return emptybytes.join(self.peek_key())
 
     def stopped(self):
+        """Returns True if the current arc leads to a stop state.
+        """
+
         raise NotImplementedError
 
     def value(self):
+        """Returns the value at the current arc, if reading an FST.
+        """
+
         raise NotImplementedError
 
     def accept(self):
+        """Returns True if the current arc leads to an accept state (the end
+        of a valid key).
+        """
+
         raise NotImplementedError
 
     def at_last_arc(self):
+        """Returns True if the current arc is the last outgoing arc from the
+        previous node.
+        """
+
         raise NotImplementedError
 
     def next_arc(self):
+        """Moves to the next outgoing arc from the previous node.
+        """
+
         raise NotImplementedError
 
     def follow(self):
+        """Follows the current arc.
+        """
+
         raise NotImplementedError
 
     def switch_to(self, label):
+        """Switch to the sibling arc with the given label bytes.
+        """
+
         _label = self.label
         _at_last_arc = self.at_last_arc
         _next_arc = self.next_arc
             _next_arc()
 
     def skip_to(self, key):
+        """Moves the cursor to the path represented by the given key bytes.
+        """
+
         _accept = self.accept
         _prefix = self.prefix
         _next_arc = self.next_arc
             _next_arc()
 
     def flatten(self):
+        """Yields the keys in the graph, starting at the current position.
+        """
+
         _is_active = self.is_active
         _accept = self.accept
         _stopped = self.stopped
             _next_arc()
 
     def flatten_v(self):
+        """Yields (key, value) tuples in an FST, starting at the current
+        position.
+        """
+
         for key in self.flatten():
             yield key, self.value()
 
     def find_path(self, path):
+        """Follows the labels in the given path, starting at the current
+        position.
+        """
+
         _switch_to = self.switch_to
         _follow = self.follow
         _stopped = self.stopped
             first = False
         return True
 
-    def follow_firsts(self):
-        while not self.stopped():
-            self.follow()
-
-    #    def follow_lasts(self):
-    #        while True:
-    #            while not self.stopped():
-    #                self.next_arc()
-    #            if self.current.target is not None:
-    #                self.follow()
-    #            else:
-    #                return
-
 
 class Cursor(BaseCursor):
+    """"A cursor-type object for navigating an FST/word graph, represented by a
+    :class:`GraphReader` object.
+    
+    >>> cur = GraphReader(dawgfile).cursor()
+    >>> for key in cur.follow():
+    ...   print(repr(key))
+    
+    The cursor "rests" on arcs in the FSA/FST graph, rather than nodes.
+    """
+
     def __init__(self, graph, root=None, stack=None):
         self.graph = graph
         self.vtype = graph.vtype
 #        return self
 
 
+class UncompiledNode(object):
+    # Represents an "in-memory" node used by the GraphWriter before it is
+    # written to disk.
+
+    compiled = False
+
+    def __init__(self, owner):
+        self.owner = owner
+        self._digest = None
+        self.clear()
+
+    def clear(self):
+        self.arcs = []
+        self.value = None
+        self.accept = False
+        self.inputcount = 0
+
+    def __repr__(self):
+        return "<%r>" % ([(a.label, a.value) for a in self.arcs],)
+
+    def digest(self):
+        if self._digest is None:
+            d = sha1()
+            vtype = self.owner.vtype
+            for arc in self.arcs:
+                d.update(arc.label)
+                if arc.target:
+                    d.update(pack_long(arc.target))
+                else:
+                    d.update(b("z"))
+                if arc.value:
+                    d.update(vtype.to_bytes(arc.value))
+                if arc.accept:
+                    d.update(b("T"))
+            self._digest = d.digest()
+        return self._digest
+
+    def edges(self):
+        return self.arcs
+
+    def last_value(self, label):
+        assert self.arcs[-1].label == label
+        return self.arcs[-1].value
+
+    def add_arc(self, label, target):
+        self.arcs.append(Arc(label, target))
+
+    def replace_last(self, label, target, accept, acceptval=None):
+        arc = self.arcs[-1]
+        assert arc.label == label, "%r != %r" % (arc.label, label)
+        arc.target = target
+        arc.accept = accept
+        arc.acceptval = acceptval
+
+    def delete_last(self, label, target):
+        arc = self.arcs.pop()
+        assert arc.label == label
+        assert arc.target == target
+
+    def set_last_value(self, label, value):
+        arc = self.arcs[-1]
+        assert arc.label == label, "%r->%r" % (arc.label, label)
+        arc.value = value
+
+    def prepend_value(self, prefix):
+        add = self.owner.vtype.add
+        for arc in self.arcs:
+            arc.value = add(prefix, arc.value)
+        if self.accept:
+            self.value = add(prefix, self.value)
+
+
+class Arc(object):
+    """Represents a directed arc between two nodes in an FSA/FST graph.
+    
+    The ``lastarc`` attribute is True if this is the last outgoing arc from the
+    previous node.
+    """
+
+    __slots__ = ("label", "target", "accept", "value", "lastarc", "acceptval",
+                 "endpos")
+
+    def __init__(self, label=None, target=None, value=None, accept=False,
+                 acceptval=None):
+        """
+        :param label:The label bytes for this arc. For a word graph, this will
+            be a character.
+        :param target: The address of the node at the endpoint of this arc.
+        :param value: The inner FST value at the endpoint of this arc.
+        :param accept: Whether the endpoint of this arc is an accept state
+            (eg the end of a valid word).
+        :param acceptval: If the endpoint of this arc is an accept state, the
+            final FST value for that accepted state.
+        """
+
+        self.label = label
+        self.target = target
+        self.value = value
+        self.accept = accept
+        self.lastarc = None
+        self.acceptval = acceptval
+        self.endpos = None
+
+    def __repr__(self):
+        return "<%r-%s %s%s>" % (self.label, self.target,
+                                 "." if self.accept else "",
+                                 (" %r" % self.value) if self.value else "")
+
+    def __eq__(self, other):
+        if (isinstance(other, self.__class__) and self.accept == other.accept
+            and self.lastarc == other.lastarc and self.target == other.target
+            and self.value == other.value and self.label == other.label):
+            return True
+        return False
+
+
+# Graph writer
+
+class GraphWriter(object):
+    """Writes an FSA/FST graph to disk.
+    
+    Call ``insert_string(bytes)`` to insert keys into the graph. You must
+    insert keys in sorted order. Call ``close()`` to finish the graph and close
+    the file.
+    
+    >>> gw = GraphWriter(my_file)
+    >>> gw.insert_string("alfa")
+    >>> gw.insert_string("bravo")
+    >>> gw.insert_string("charlie")
+    >>> gw.close()
+    
+    The graph writer can write separate graphs for multiple fields. Use
+    ``start_field(name)`` and ``finish_field()`` to separate fields.
+    
+    >>> gw = GraphWriter(my_file)
+    >>> gw.start_field("content")
+    >>> gw.insert_u16("alfalfa")
+    >>> gw.insert_u16("apple")
+    >>> gw.finish_field()
+    >>> gw.start_field("title")
+    >>> gw.insert_u16("artichoke")
+    >>> gw.finish_field()
+    >>> gw.close()
+    """
+
+    version = 1
+
+    def __init__(self, dbfile, vtype=None, merge=None):
+        """
+        :param dbfile: the file to write to.
+        :param vtype: a :class:`Values` class to use for storing values. This
+            is only necessary if you will be storing values for the keys.
+        :param merge: a function that takes two values and returns a single
+            value. This is called if you insert two identical keys with values.
+        """
+
+        self.dbfile = dbfile
+        self.vtype = vtype
+        self.merge = merge
+        self.fieldroots = {}
+        self.arc_count = 0
+        self.node_count = 0
+        self.fixed_count = 0
+
+        dbfile.write(b("GRPH"))
+        dbfile.write_int(self.version)
+        dbfile.write_uint(0)
+
+        self.fieldname = None
+        self.start_field("_")
+
+    def start_field(self, fieldname):
+        """Starts a new graph for the given field.
+        """
+
+        if not fieldname:
+            raise ValueError("Field name cannot be equivalent to False")
+        if self.fieldname is not None:
+            self.finish_field()
+        self.fieldname = fieldname
+        self.seen = {}
+        self.nodes = [UncompiledNode(self)]
+        self.lastkey = ''
+        self._inserted = False
+
+    def finish_field(self):
+        """Finishes the graph for the current field.
+        """
+
+        if self._inserted:
+            self.fieldroots[self.fieldname] = self._finish()
+        self.fieldname = None
+
+    def close(self):
+        """Finishes the current graph and closes the underlying file.
+        """
+
+        if self.fieldname is not None:
+            self.finish_field()
+        dbfile = self.dbfile
+        here = dbfile.tell()
+        dbfile.write_pickle(self.fieldroots)
+        dbfile.flush()
+        dbfile.seek(4 + _INT_SIZE)  # Seek past magic and version number
+        dbfile.write_uint(here)
+        dbfile.close()
+
+    def insert(self, key, value=None):
+        """Inserts the given sequence of bytestrings as a key.
+        
+        This will work with Python 2 ``str`` objects but WON'T work with Python
+        3 ``bytes`` objects because they act like sequences of numbers, not
+        bytestrings. For consistency, instead use ``insert_string()`` which
+        accepts both bytes and unicode.
+        
+        :param key: a sequence of bytestrings.
+        :param value: an optional value to encode in the graph along with the
+            key. If the writer was not instantiated with a value type, passing
+            a value here will raise an error.
+        """
+
+        if self.fieldname is None:
+            raise Exception("Inserted %r before starting a field" % key)
+        self._inserted = True
+
+        vtype = self.vtype
+        lastkey = self.lastkey
+        nodes = self.nodes
+        if len(key) < 1:
+            raise KeyError("Can't store a null key %r" % key)
+        if lastkey and lastkey > key:
+            raise KeyError("Keys out of order %r..%r" % (lastkey, key))
+
+        # Find the common prefix shared by this key and the previous one
+        prefixlen = 0
+        for i in xrange(min(len(lastkey), len(key))):
+            if lastkey[i] != key[i]:
+                break
+            prefixlen += 1
+        # Compile the nodes after the prefix, since they're not shared
+        self._freeze_tail(prefixlen + 1)
+
+        # Create new nodes for the parts of this key after the shared prefix
+        for char in key[prefixlen:]:
+            node = UncompiledNode(self)
+            # Create an arc to this node on the previous node
+            nodes[-1].add_arc(char, node)
+            nodes.append(node)
+        # Mark the last node as an accept state
+        lastnode = nodes[-1]
+        lastnode.accept = True
+
+        if vtype:
+            if value is not None and not vtype.is_valid(value):
+                raise ValueError("%r is not valid for %s" % (value, vtype))
+
+            # Push value commonalities through the tree
+            common = None
+            for i in xrange(1, prefixlen + 1):
+                node = nodes[i]
+                parent = nodes[i - 1]
+                lastvalue = parent.last_value(key[i - 1])
+                if lastvalue is not None:
+                    common = vtype.common(value, lastvalue)
+                    suffix = vtype.subtract(lastvalue, common)
+                    parent.set_last_value(key[i - 1], common)
+                    node.prepend_value(suffix)
+                else:
+                    common = suffix = None
+                value = vtype.subtract(value, common)
+
+            if key == lastkey:
+                # If this key is a duplicate, merge its value with the value of
+                # the previous (same) key
+                lastnode.value = self.merge(lastnode.value, value)
+            else:
+                nodes[prefixlen].set_last_value(key[prefixlen], value)
+        elif value:
+            raise Exception("Value %r but no value type" % value)
+
+        self.lastkey = key
+
+    def insert_string(self, key, value=None):
+        """This method converts the given ``key`` string into a sequence of
+        UTF-8 encoded bytestrings for each character and passes it to the
+        ``insert()`` method. It should work with bytes and all string
+        representations.
+        """
+
+        # I hate the Python 3 bytes object so friggin much
+        if isinstance(key, bytes_type):
+            k = [key[i:i + 1] for i in xrange(len(key))]
+        else:
+            k = [utf8encode(key[i:i + 1])[0] for i in xrange(len(key))]
+        self.insert(k, value=value)
+
+    def _freeze_tail(self, prefixlen):
+        nodes = self.nodes
+        lastkey = self.lastkey
+        downto = max(1, prefixlen)
+
+        while len(nodes) > downto:
+            node = nodes.pop()
+            parent = nodes[-1]
+            inlabel = lastkey[len(nodes) - 1]
+
+            self._compile_targets(node)
+            accept = node.accept or len(node.arcs) == 0
+            address = self._compile_node(node)
+            parent.replace_last(inlabel, address, accept, node.value)
+
+    def _finish(self):
+        nodes = self.nodes
+        root = nodes[0]
+        # Minimize nodes in the last word's suffix
+        self._freeze_tail(0)
+        # Compile remaining targets
+        self._compile_targets(root)
+        return self._compile_node(root)
+
+    def _compile_targets(self, node):
+        for arc in node.arcs:
+            if isinstance(arc.target, UncompiledNode):
+                n = arc.target
+                if len(n.arcs) == 0:
+                    arc.accept = n.accept = True
+                arc.target = self._compile_node(n)
+
+    def _compile_node(self, uncnode):
+        seen = self.seen
+
+        if len(uncnode.arcs) == 0:
+            # Leaf node
+            address = self._write_node(uncnode)
+        else:
+            d = uncnode.digest()
+            address = seen.get(d)
+            if address is None:
+                address = self._write_node(uncnode)
+                seen[d] = address
+        return address
+
+    def _write_node(self, uncnode):
+        vtype = self.vtype
+        dbfile = self.dbfile
+        arcs = uncnode.arcs
+        numarcs = len(arcs)
+
+        if not numarcs:
+            if uncnode.accept:
+                return None
+            else:
+                # What does it mean for an arc to stop but not be accepted?
+                raise Exception
+        self.node_count += 1
+
+        buf = StructFile(BytesIO())
+        nodestart = dbfile.tell()
+        #self.count += 1
+        #self.arccount += numarcs
+
+        fixedsize = -1
+        arcstart = buf.tell()
+        for i, arc in enumerate(arcs):
+            self.arc_count += 1
+            target = arc.target
+            label = arc.label
+
+            flags = 0
+            if len(label) > 1:
+                flags += MULTIBYTE_LABEL
+            if i == numarcs - 1:
+                flags += ARC_LAST
+            if arc.accept:
+                flags += ARC_ACCEPT
+            if target is None:
+                flags += ARC_STOP
+            if arc.value is not None:
+                flags += ARC_HAS_VAL
+            if arc.acceptval is not None:
+                flags += ARC_HAS_ACCEPT_VAL
+
+            buf.write(pack_byte(flags))
+            if len(label) > 1:
+                buf.write(varint(len(label)))
+            buf.write(label)
+            if target is not None:
+                buf.write(pack_uint(target))
+            if arc.value is not None:
+                vtype.write(buf, arc.value)
+            if arc.acceptval is not None:
+                vtype.write(buf, arc.acceptval)
+
+            here = buf.tell()
+            thissize = here - arcstart
+            arcstart = here
+            if fixedsize == -1:
+                fixedsize = thissize
+            elif fixedsize > 0 and thissize != fixedsize:
+                fixedsize = 0
+
+        if fixedsize > 0:
+            # Write a fake arc containing the fixed size and number of arcs
+            dbfile.write_byte(255)  # FIXED_SIZE
+            dbfile.write_int(fixedsize)
+            dbfile.write_int(numarcs)
+            self.fixed_count += 1
+        dbfile.write(buf.file.getvalue())
+
+        return nodestart
+
+
 # Graph reader
 
 class BaseGraphReader(object):
 
 
 class GraphReader(BaseGraphReader):
-    def __init__(self, dbfile, rootname=None, labelsize=1, vtype=None,
-                 filebase=0):
+    def __init__(self, dbfile, rootname=None, vtype=None, filebase=0):
         self.dbfile = dbfile
-        self.labelsize = labelsize
         self.vtype = vtype
         self.filebase = filebase
 
         dbfile = self.dbfile
         flags = dbfile.read_byte()
         if flags == 255:
-            # FIXED_SIZE
+            # This is a fake arc containing fixed size information; skip it
+            # and read the next arc
             dbfile.seek(_INT_SIZE * 2, 1)
             flags = dbfile.read_byte()
-        toarc.label = dbfile.read(self.labelsize)
+        toarc.label = self._read_label(flags)
         return self._read_arc_data(flags, toarc)
 
+    def _read_label(self, flags):
+        dbfile = self.dbfile
+        if flags & MULTIBYTE_LABEL:
+            length = dbfile.read_varint()
+        else:
+            length = 1
+        label = dbfile.read(length)
+        return label
+
     def _read_fixed_info(self):
         dbfile = self.dbfile
 
 
     def _read_arc_data(self, flags, arc):
         dbfile = self.dbfile
-        accept = arc.accept = bool(flags & 2)
-        arc.lastarc = flags & 1
-        if flags & 4:  # STOP_NODE
+        accept = arc.accept = bool(flags & ARC_ACCEPT)
+        arc.lastarc = flags & ARC_LAST
+        if flags & ARC_STOP:
             arc.target = None
         else:
             arc.target = dbfile.read_uint()
-        if flags & 8:  # ARC_HAS_VALUE
+        if flags & ARC_HAS_VAL:
             arc.value = self.vtype.read(dbfile)
         else:
             arc.value = None
-        if accept and flags & 16:  # ARC_HAS_ACCEPT_VALUE
+        if accept and flags & ARC_HAS_ACCEPT_VAL:
             arc.acceptval = self.vtype.read(dbfile)
         arc.endpos = dbfile.tell()
         return arc
 
     def _binary_search(self, address, size, count, label, arc):
         dbfile = self.dbfile
-        labelsize = self.labelsize
+        _read_label = self._read_label
 
         lo = 0
         hi = count
             midaddr = address + mid * size
             dbfile.seek(midaddr)
             flags = dbfile.read_byte()
-            midlabel = dbfile.read(labelsize)
+            midlabel = self._read_label(flags)
             if midlabel == label:
                 arc.label = midlabel
                 return self._read_arc_data(flags, arc)
                                       sofar + char2 + char, arc.accept))
 
 
-# Graph writer
-
-class UncompiledNode(object):
-    compiled = False
-
-    def __init__(self, owner):
-        self.owner = owner
-        self.clear()
-
-    def clear(self):
-        self.arcs = []
-        self.value = None
-        self.accept = False
-        self.inputcount = 0
-
-    def __repr__(self):
-        return "<%r>" % ([(a.label, a.value) for a in self.arcs],)
-
-    def digest(self):
-        d = sha1()
-        vtype = self.owner.vtype
-        for arc in self.arcs:
-            d.update(arc.label)
-            if arc.target:
-                d.update(pack_long(arc.target))
-            else:
-                d.update("z")
-            if arc.value:
-                d.update(vtype.to_bytes(arc.value))
-            if arc.accept:
-                d.update(b("T"))
-        return d.digest()
-
-    def edges(self):
-        return self.arcs
-
-    def last_value(self, label):
-        assert self.arcs[-1].label == label
-        return self.arcs[-1].value
-
-    def add_arc(self, label, target):
-        self.arcs.append(Arc(label, target))
-
-    def replace_last(self, label, target, accept, acceptval=None):
-        arc = self.arcs[-1]
-        assert arc.label == label, "%r != %r" % (arc.label, label)
-        arc.target = target
-        arc.accept = accept
-        arc.acceptval = acceptval
-
-    def delete_last(self, label, target):
-        arc = self.arcs.pop()
-        assert arc.label == label
-        assert arc.target == target
-
-    def set_last_value(self, label, value):
-        arc = self.arcs[-1]
-        assert arc.label == label, "%r->%r" % (arc.label, label)
-        arc.value = value
-
-    def prepend_value(self, prefix):
-        add = self.owner.vtype.add
-        for arc in self.arcs:
-            arc.value = add(prefix, arc.value)
-        if self.accept:
-            self.value = add(prefix, self.value)
-
-
-class Arc(object):
-    __slots__ = ("label", "target", "accept", "value", "lastarc", "acceptval",
-                 "endpos")
-
-    def __init__(self, label=None, target=None, value=None, accept=False,
-                 acceptval=None):
-        self.label = label
-        self.target = target
-        self.value = value
-        self.accept = accept
-        self.lastarc = None
-        self.acceptval = acceptval
-        self.endpos = None
-
-    def __repr__(self):
-        return "<%r-%s %s%s>" % (self.label, self.target,
-                                 "." if self.accept else "",
-                                 (" %r" % self.value) if self.value else "")
-
-    def __eq__(self, other):
-        if (isinstance(other, self.__class__) and self.accept == other.accept
-            and self.lastarc == other.lastarc and self.target == other.target
-            and self.value == other.value and self.label == other.label):
-            return True
-        return False
-
-
-class GraphWriter(object):
-    version = 1
-
-    def __init__(self, dbfile, vtype=None, merge=None):
-        """
-        :param dbfile: the file to write to.
-        :param vtype: a :class:`Values` class to use for storing values. This
-            is only necessary if you will be storing values for the keys.
-        :param merge: a function that takes two values and returns a single
-            value. This is called if you insert two identical keys with values.
-        """
-
-        self.dbfile = dbfile
-        self.vtype = vtype
-        self.merge = merge
-        self.fieldroots = {}
-
-        dbfile.write(b("GRPH"))
-        dbfile.write_int(self.version)
-        dbfile.write_uint(0)
-
-        self.fieldname = None
-        self.start_field("_")
-
-    def start_field(self, fieldname):
-        if not fieldname:
-            raise ValueError("Field name cannot be equivalent to False")
-        if self.fieldname is not None:
-            self.finish_field()
-        self.fieldname = fieldname
-        self.seen = {}
-        self.nodes = [UncompiledNode(self)]
-        self.lastkey = ''
-        self._inserted = False
-
-    def finish_field(self):
-        if self._inserted:
-            self.fieldroots[self.fieldname] = self._finish()
-        self.fieldname = None
-
-    def close(self):
-        if self.fieldname is not None:
-            self.finish_field()
-        dbfile = self.dbfile
-        here = dbfile.tell()
-        dbfile.write_pickle(self.fieldroots)
-        dbfile.flush()
-        dbfile.seek(4 + _INT_SIZE)  # Seek past magic and version number
-        dbfile.write_uint(here)
-        dbfile.close()
-
-    def insert(self, key, value=None):
-        if self.fieldname is None:
-            raise Exception("Inserted %r before starting a field" % key)
-        self._inserted = True
-
-        vtype = self.vtype
-        lastkey = self.lastkey
-        nodes = self.nodes
-        if len(key) < 1:
-            raise KeyError("Can't store a null key %r" % key)
-        if self.lastkey > key:
-            raise KeyError("Keys out of order %r..%r" % (self.lastkey, key))
-
-        # Find the common prefix shared by this key and the previous one
-        prefixlen = 0
-        for i in xrange(min(len(lastkey), len(key))):
-            if lastkey[i] != key[i]:
-                break
-            prefixlen += 1
-        # Compile the nodes after the prefix, since they're not shared
-        self._freeze_tail(prefixlen + 1)
-
-        # Create new nodes for the parts of this key after the shared prefix
-        for char in key[prefixlen:]:
-            node = UncompiledNode(self)
-            # Create an arc to this node on the previous node
-            nodes[-1].add_arc(char, node)
-            nodes.append(node)
-        # Mark the last node as an accept state
-        lastnode = nodes[-1]
-        lastnode.accept = True
-
-        if vtype:
-            if value is not None and not vtype.is_valid(value):
-                raise ValueError("%r is not valid for %s" % (value, vtype))
-
-            # Push value commonalities through the tree
-            common = None
-            for i in xrange(1, prefixlen + 1):
-                node = nodes[i]
-                parent = nodes[i - 1]
-                lastvalue = parent.last_value(key[i - 1])
-                if lastvalue is not None:
-                    common = vtype.common(value, lastvalue)
-                    suffix = vtype.subtract(lastvalue, common)
-                    parent.set_last_value(key[i - 1], common)
-                    node.prepend_value(suffix)
-                else:
-                    common = suffix = None
-                value = vtype.subtract(value, common)
-
-            if key == lastkey:
-                # If this key is a duplicate, merge its value with the value of
-                # the previous (same) key
-                lastnode.value = self.merge(lastnode.value, value)
-            else:
-                nodes[prefixlen].set_last_value(key[prefixlen], value)
-        elif value:
-            raise Exception("Value %r but no value type" % value)
-
-        self.lastkey = key
-
-    def _freeze_tail(self, prefixlen):
-        nodes = self.nodes
-        lastkey = self.lastkey
-        downto = max(1, prefixlen)
-
-        while len(nodes) > downto:
-            node = nodes.pop()
-            parent = nodes[-1]
-            inlabel = lastkey[len(nodes) - 1]
-
-            self._compile_targets(node)
-            accept = node.accept or len(node.arcs) == 0
-            address = self._compile_node(node)
-            parent.replace_last(inlabel, address, accept, node.value)
-
-    def _finish(self):
-        nodes = self.nodes
-        root = nodes[0]
-        # Minimize nodes in the last word's suffix
-        self._freeze_tail(0)
-        # Compile remaining targets
-        self._compile_targets(root)
-        return self._compile_node(root)
-
-    def _compile_targets(self, node):
-        for arc in node.arcs:
-            if isinstance(arc.target, UncompiledNode):
-                n = arc.target
-                if len(n.arcs) == 0:
-                    arc.accept = n.accept = True
-                arc.target = self._compile_node(n)
-
-    def _compile_node(self, uncnode):
-        seen = self.seen
-
-        if len(uncnode.arcs) == 0:
-            # Leaf node
-            address = self._write_node(uncnode)
-        else:
-            d = uncnode.digest()
-            address = seen.get(d)
-            if address is None:
-                address = self._write_node(uncnode)
-                seen[d] = address
-        return address
-
-    def _write_node(self, uncnode):
-        vtype = self.vtype
-        dbfile = self.dbfile
-        arcs = uncnode.arcs
-        numarcs = len(arcs)
-
-        if not numarcs:
-            if uncnode.accept:
-                return None
-            else:
-                # What does it mean for an arc to stop but not be final?
-                raise Exception
-
-        buf = StructFile(BytesIO())
-        nodestart = dbfile.tell()
-        #self.count += 1
-        #self.arccount += numarcs
-
-        fixedsize = -1
-        arcstart = buf.tell()
-        for i, arc in enumerate(arcs):
-            target = arc.target
-
-            flags = 0
-            if i == numarcs - 1:
-                flags += 1  # LAST_ARC
-            if arc.accept:
-                flags += 2    # FINAL_ARC
-            if target is None:
-                # Target has no arcs
-                flags += 4  # STOP_NODE
-            if arc.value is not None:
-                flags += 8  # ARC_HAS_VALUE
-            if arc.acceptval is not None:
-                flags += 16  # ARC_HAS_ACCEPT_VAL
-
-            buf.write(pack_byte(flags))
-            buf.write(arc.label)
-            if target >= 0:
-                buf.write(pack_uint(target))
-            if arc.value is not None:
-                vtype.write(buf, arc.value)
-            if arc.acceptval is not None:
-                vtype.write(buf, arc.acceptval)
-
-            here = buf.tell()
-            thissize = here - arcstart
-            arcstart = here
-            if fixedsize == -1:
-                fixedsize = thissize
-            elif fixedsize > 0 and thissize != fixedsize:
-                fixedsize = 0
-
-        if fixedsize > 0:
-            # Write a fake arc containing the fixed size and number of arcs
-            dbfile.write_byte(255)  # FIXED_SIZE
-            dbfile.write_int(fixedsize)
-            dbfile.write_int(numarcs)
-        dbfile.write(buf.file.getvalue())
-
-        return nodestart
-
-
 # Utility functions
 
 def dump_graph(graph, address=None, tab=0, out=None):

src/whoosh/util.py

 # Note: these functions return a tuple of (text, length), so when you call
 # them, you have to add [0] on the end, e.g. str = utf8encode(unicode)[0]
 
-utf8encode = codecs.getencoder("utf_8")
-utf8decode = codecs.getdecoder("utf_8")
+utf8encode = codecs.getencoder("utf-8")
+utf8decode = codecs.getdecoder("utf-8")
+
+utf16encode = codecs.getencoder("utf-16-be")
+utf16decode = codecs.getdecoder("utf-16-be")
+
+utf32encode = codecs.getencoder("utf-32-be")
+utf32decode = codecs.getdecoder("utf-32-be")
 
 
 # Functions

tests/test_dawg.py

     f = st.create_file("test")
     gw = dawg.GraphWriter(f)
     for key in keys:
-        gw.insert(key)
+        gw.insert_string(key)
     gw.close()
     return st
 
 def test_keys_out_of_order():
     f = RamStorage().create_file("test")
     gw = dawg.GraphWriter(f)
-    gw.insert(b("alfa"))
-    assert_raises(KeyError, gw.insert, b("abba"))
+    gw.insert_string("alfa")
+    assert_raises(KeyError, gw.insert_string, "abba")
 
 def test_duplicate_keys():
     st = gwrite(enlist("alfa bravo bravo bravo charlie"))
     assert_raises(dawg.InactiveCursor, list, cur.flatten())
     assert_raises(dawg.InactiveCursor, list, cur.flatten_v())
     assert_raises(dawg.InactiveCursor, cur.find_path, b("a"))
-    assert_raises(dawg.InactiveCursor, cur.follow_firsts)
 
 def test_types():
     st = RamStorage()
         f = st.create_file("test")
         gw = dawg.GraphWriter(f, vtype=t)
         for key, value in domain:
-            gw.insert(key, value)
+            gw.insert_string(key, value)
         gw.close()
 
         f = st.open_file("test")
               ]
     _fst_roundtrip(domain, dawg.IntListValues)
 
-#def test_fst_merge():
-#    # 2; 3; 5; 7; 11; 13; 17; 19
-#    ins = [(b("000"), 2), (b("000"), 2), (b("001"), 3), (b("010"), 5),
-#           (b("010"), 5), (b("011"), 7), (b("100"), 11), (b("101"), 13),
-#           (b("101"), 13), (b("110"), 17), (b("111"), 19), (b("111"), 19)]
-#    outs = [(b("000"), 4), (b("001"), 3), (b("010"), 10), (b("011"), 7),
-#            (b("100"), 11), (b("101"), 26), (b("110"), 17), (b("111"), 38)]
-#
-#    with TempStorage() as st:
-#        f = st.create_file("test")
-#        gw = dawg.GraphWriter(f, vtype=dawg.IntValues,
-#                              merge=lambda v1, v2: v1 + v2)
-#        for key, value in ins:
-#            gw.insert(key, value)
-#        gw.close()
-#
-#        f = st.open_file("test")
-
 def test_words():
     words = enlist("alfa alpaca amtrak bellow fellow fiona zebulon")
     with TempStorage() as st:
         f = st.create_file("test")
         gw = dawg.GraphWriter(f)
         gw.start_field("f1")
-        gw.insert(b("a"))
-        gw.insert(b("aa"))
-        gw.insert(b("ab"))
+        gw.insert_string("a")
+        gw.insert_string("aa")
+        gw.insert_string("ab")
         gw.finish_field()
         gw.start_field("f2")
-        gw.insert(b("ba"))
-        gw.insert(b("baa"))
-        gw.insert(b("bab"))
+        gw.insert_string("ba")
+        gw.insert_string("baa")
+        gw.insert_string("bab")
         gw.close()
 
         gr = dawg.GraphReader(st.open_file("test"))
     st = gwrite(enlist("abcd abfg cdqr1 cdqr12 cdxy wxyz"))
     gr = greader(st)
     cur = gr.cursor()
-    cur.follow_firsts()
+    while not cur.stopped(): cur.follow()
     assert_equal(cur.prefix_bytes(), b("abcd"))
     assert cur.accept()
     cur._pop_to_prefix("abzz")
     assert_equal(cur.prefix_bytes(), b("abf"))
 
     cur = gr.cursor()
-    cur.follow_firsts()
+    while not cur.stopped(): cur.follow()
     assert_equal(cur.prefix_bytes(), b("abcd"))
     cur.skip_to(b("cdaa"))
     assert_equal(cur.peek_key_bytes(), b("cdqr1"))
     assert_equal(cur.prefix_bytes(), b("cdq"))
 
     cur = gr.cursor()
-    cur.follow_firsts()
+    while not cur.stopped(): cur.follow()
     cur.skip_to(b("z"))
     assert not cur.is_active()
 
+def test_insert_bytes():
+    # This test is only meaningful on Python 3
+    domain = [b("alfa"), b("bravo"), b("charlie")]
 
+    st = RamStorage()
+    gw = dawg.GraphWriter(st.create_file("test"))
+    for key in domain:
+        gw.insert_string(key)
+    gw.close()
 
+    cur = dawg.GraphReader(st.open_file("test")).cursor()
+    assert_equal(list(cur.flatten()), domain)
 
+def test_insert_unicode():
+    domain = [u("\u280b\u2817\u2801\u281d\u2809\u2811"),
+              u("\u65e5\u672c"),
+              u("\uc774\uc124\ud76c"),
+              ]
 
+    st = RamStorage()
+    gw = dawg.GraphWriter(st.create_file("test"))
+    for key in domain:
+        gw.insert_string(key)
+    gw.close()
 
+    cur = dawg.GraphReader(st.open_file("test")).cursor()
+    assert_equal(list(cur.flatten()), domain)
 
+def test_within_unicode():
+    domain = [u("\u280b\u2817\u2801\u281d\u2809\u2811"),
+              u("\u65e5\u672c"),
+              u("\uc774\uc124\ud76c"),
+              ]
 
+    st = RamStorage()
+    gw = dawg.GraphWriter(st.create_file("test"))
+    for key in domain:
+        gw.insert_string(key)
+    gw.close()
+
+    gr = dawg.GraphReader(st.open_file("test"))
+    s = list(dawg.within(gr, u("\uc774.\ud76c")))
+    assert_equal(s, [u("\uc774\uc124\ud76c")])
+
+
+
+
+
+
+

tests/test_spelling.py

         path = fname
     else:
         return
-
     if not os.path.exists(path):
         return
-    wordfile = gzip.open(path, "r")
+
+    wordfile = gzip.open(path, "rb")
     cor = words_to_corrector(wordfile)
     wordfile.close()
     assert_equal(cor.suggest("specail"), ["special"])
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.