Commits

Matt Chaput  committed c33e697

Experimented with intersection meta-cursor.
Added GraphReader.find_path() and edited within() to use it.
Converted Values methods to staticmethods.

  • Participants
  • Parent commits be46cfc

Comments (0)

Files changed (2)

File src/whoosh/support/dawg.py

 from array import array
 from hashlib import sha1  #@UnresolvedImport
 
-from whoosh.compat import b, BytesIO, xrange, iteritems, iterkeys, bytes_type
+from whoosh.compat import (b, BytesIO, xrange, iteritems, iterkeys, bytes_type,
+                           izip)
 from whoosh.filedb.structfile import StructFile
 from whoosh.system import _INT_SIZE, pack_byte, pack_int, pack_uint, pack_long
 
 class FileVersionError(Exception):
     pass
 
+
+class InactiveCursor(Exception):
+    pass
+
+
 emptybytes = b("")
 
 
 # FST Value types
 
 class Values(object):
-    @classmethod
-    def is_valid(cls, v):
+    """Base for classes that describe how to encode and decode FST values.
+    """
+
+    @staticmethod
+    def is_valid(v):
         """Returns True if v is a valid object that can be stored by this
         class.
         """
 
         raise NotImplementedError
 
-    @classmethod
-    def common(cls, v1, v2):
+    @staticmethod
+    def common(v1, v2):
         """Returns the "common" part of the two values, for whatever "common"
         means for this class. For example, a string implementation would return
         the common shared prefix, for an int implementation it would return
 
         raise NotImplementedError
 
-    @classmethod
-    def add(cls, prefix, v):
+    @staticmethod
+    def add(prefix, v):
         """Adds the given prefix (the result of a call to common()) to the
         given value.
         """
 
         raise NotImplementedError
 
-    @classmethod
-    def subtract(cls, v, prefix):
+    @staticmethod
+    def subtract(v, prefix):
         """Subtracts the "common" part (the prefix) from the given value.
         """
 
         raise NotImplementedError
 
-    @classmethod
-    def write(cls, dbfile, v):
+    @staticmethod
+    def write(dbfile, v):
         """Writes value v to a file.
         """
 
         raise NotImplementedError
 
-    @classmethod
-    def read(cls, dbfile):
+    @staticmethod
+    def read(dbfile):
         """Reads a value from the given file.
         """
 
 
         cls.read(dbfile)
 
-    @classmethod
-    def to_bytes(cls, v):
+    @staticmethod
+    def to_bytes(v):
         """Returns a str (Python 2.x) or bytes (Python 3) representation of
         the given value. This is used for calculating node digests, so it
         should be unique but fast to calculate, and does not have to be
 
         raise NotImplementedError
 
-    @classmethod
-    def merge(cls, v1, v2):
+    @staticmethod
+    def merge(v1, v2):
         raise NotImplementedError
 
 
 class IntValues(Values):
-    @classmethod
-    def is_valid(self, v):
+    """Stores integer values in an FST.
+    """
+
+    @staticmethod
+    def is_valid(v):
         return isinstance(v, int) and v >= 0
 
-    @classmethod
-    def common(cls, v1, v2):
+    @staticmethod
+    def common(v1, v2):
         if v1 is None or v2 is None:
             return None
         if v1 == v2:
             return v1
         return min(v1, v2)
 
-    @classmethod
-    def add(cls, base, v):
+    @staticmethod
+    def add(base, v):
         if base is None:
             return v
         if v is None:
             return base
         return base + v
 
-    @classmethod
-    def subtract(cls, v, base):
+    @staticmethod
+    def subtract(v, base):
         if v is None:
             return None
         if base is None:
             return v
         return v - base
 
-    @classmethod
-    def write(cls, dbfile, v):
+    @staticmethod
+    def write(dbfile, v):
         dbfile.write_uint(v)
 
-    @classmethod
-    def read(cls, dbfile):
+    @staticmethod
+    def read(dbfile):
         return dbfile.read_uint()
 
-    @classmethod
-    def skip(cls, dbfile):
+    @staticmethod
+    def skip(dbfile):
         dbfile.seek(_INT_SIZE, 1)
 
-    @classmethod
-    def to_bytes(cls, v):
+    @staticmethod
+    def to_bytes(v):
         return pack_int(v)
 
 
 class SequenceValues(Values):
-    @classmethod
-    def is_valid(cls, v):
+    """Abstract base class for value types that store sequences.
+    """
+
+    @staticmethod
+    def is_valid(v):
         return isinstance(self, (list, tuple))
 
-    @classmethod
-    def common(cls, v1, v2):
+    @staticmethod
+    def common(v1, v2):
         if v1 is None or v2 is None:
             return None
 
             return v2
         return v1[:i]
 
-    @classmethod
-    def add(cls, prefix, v):
+    @staticmethod
+    def add(prefix, v):
         if prefix is None:
             return v
         if v is None:
             return prefix
         return prefix + v
 
-    @classmethod
-    def subtract(cls, v, prefix):
+    @staticmethod
+    def subtract(v, prefix):
         if prefix is None:
             return v
         if v is None:
             raise ValueError((v, prefix))
         return v[len(prefix):]
 
-    @classmethod
-    def write(cls, dbfile, v):
+    @staticmethod
+    def write(dbfile, v):
         dbfile.write_pickle(v)
 
-    @classmethod
-    def read(cls, dbfile):
+    @staticmethod
+    def read(dbfile):
         return dbfile.read_pickle()
 
-    @classmethod
-    def to_bytes(cls, v):
-        return b(str(v))
-
 
 class BytesValues(SequenceValues):
-    @classmethod
-    def is_valid(self, v):
+    """Stores bytes objects (str in Python 2.x) in an FST.
+    """
+
+    @staticmethod
+    def is_valid(v):
         return isinstance(v, bytes_type)
 
-    @classmethod
-    def write(cls, dbfile, v):
+    @staticmethod
+    def write(dbfile, v):
         dbfile.write_int(len(v))
         dbfile.write(v)
 
-    @classmethod
-    def read(cls, dbfile):
+    @staticmethod
+    def read(dbfile):
         length = dbfile.read_int()
         return dbfile.read(length)
 
-    @classmethod
-    def skip(cls, dbfile):
+    @staticmethod
+    def skip(dbfile):
         length = dbfile.read_int()
         dbfile.seek(length, 1)
 
-    @classmethod
-    def to_bytes(cls, v):
+    @staticmethod
+    def to_bytes(v):
         return v
 
 
 class ArrayValues(SequenceValues):
-    @classmethod
-    def is_valid(self, v):
+    """Stores array.array objects in an FST.
+    """
+
+    @staticmethod
+    def is_valid(v):
         return isinstance(v, array)
 
-    @classmethod
-    def write(cls, dbfile, v):
+    @staticmethod
+    def write(dbfile, v):
         dbfile.write(b(v.typecode))
         dbfile.write_int(len(v))
         dbfile.write_array(v)
 
-    @classmethod
-    def read(cls, dbfile):
+    @staticmethod
+    def read(dbfile):
         typecode = b(dbfile.read(1))
         length = dbfile.read_int()
         return dbfile.read_array(typecode, length)
 
-    @classmethod
+    @staticmethod
     def skip(dbfile):
         typecode = b(dbfile.read(1))
         length = dbfile.read_int()
         a = array(typecode)
         dbfile.seek(length * a.itemsize, 1)
 
-    @classmethod
-    def to_bytes(cls, v):
+    @staticmethod
+    def to_bytes(v):
         return v.tostring()
 
 
 class IntListValues(SequenceValues):
-    @classmethod
-    def is_valid(self, v):
+    """Stores lists of non-negative, non-decreasing integers (that is, lists of
+    integers where each number is >= 0 and each number is greater than or equal
+    to the number that precedes it) in an FST.
+    """
+
+    @staticmethod
+    def is_valid(v):
         if isinstance(v, (list, tuple)):
             if len(v) < 2:
                 return True
             return True
         return False
 
-    @classmethod
-    def write(cls, dbfile, v):
+    @staticmethod
+    def write(dbfile, v):
         base = 0
         dbfile.write_varint(len(v))
         for x in v:
             dbfile.write_varint(delta)
             base = x
 
-    @classmethod
-    def read(cls, dbfile):
+    @staticmethod
+    def read(dbfile):
         length = dbfile.read_varint()
         result = []
         if length > 0:
                 result.append(base)
         return result
 
+    @staticmethod
+    def to_bytes(v):
+        return b(repr(v))
+
 
 # Node-like interface wrappers
 
 
 # Cursor
 
-class EndOfCursor(Exception):
-    pass
+class BaseCursor(object):
+    def is_active(self):
+        raise NotImplementedError
 
+    def label(self):
+        raise NotImplementedError
 
-class Cursor(object):
+    def prefix(self):
+        raise NotImplementedError
+
+    def prefix_bytes(self):
+        return emptybytes.join(self.prefix())
+
+    def peek_key(self):
+        for label in self.prefix():
+            yield label
+        c = self.copy()
+        while not c.stopped():
+            c.follow()
+            yield c.label()
+
+    def peek_key_bytes(self):
+        return emptybytes.join(self.peek_key())
+
+    def stopped(self):
+        raise NotImplementedError
+
+    def value(self):
+        raise NotImplementedError
+
+    def accept(self):
+        raise NotImplementedError
+
+    def at_last_arc(self):
+        raise NotImplementedError
+
+    def next_arc(self):
+        raise NotImplementedError
+
+    def follow(self):
+        raise NotImplementedError
+
+    def switch_to(self, label):
+        _label = self.label
+        _at_last_arc = self.at_last_arc
+        _next_arc = self.next_arc
+
+        while True:
+            thislabel = _label()
+            if thislabel == label:
+                return True
+            if thislabel > label or _at_last_arc():
+                return False
+            _next_arc()
+
+    def skip_to(self, key):
+        _accept = self.accept
+        _prefix = self.prefix
+        _next_arc = self.next_arc
+
+        keylist = list(key)
+        while True:
+            if _accept():
+                thiskey = list(_prefix())
+                if keylist == thiskey:
+                    return True
+                elif keylist > thiskey:
+                    return False
+            _next_arc()
+
+    def flatten(self):
+        _is_active = self.is_active
+        _accept = self.accept
+        _stopped = self.stopped
+        _follow = self.follow
+        _next_arc = self.next_arc
+        _prefix_bytes = self.prefix_bytes
+
+        if not _is_active():
+            raise InactiveCursor
+        while _is_active():
+            if _accept():
+                yield _prefix_bytes()
+            if not _stopped():
+                _follow()
+                continue
+            _next_arc()
+
+    def flatten_v(self):
+        for key in self.flatten():
+            yield key, self.value()
+
+    def find_path(self, path):
+        _switch_to = self.switch_to
+        _follow = self.follow
+        _stopped = self.stopped
+
+        first = True
+        for i, label in enumerate(path):
+            if not first:
+                _follow()
+            if not _switch_to(label):
+                return False
+            if _stopped():
+                if i < len(path) - 1:
+                    return False
+            first = False
+        return True
+
+    def follow_firsts(self):
+        while not self.stopped():
+            self.follow()
+
+    #    def follow_lasts(self):
+    #        while True:
+    #            while not self.stopped():
+    #                self.next_arc()
+    #            if self.current.target is not None:
+    #                self.follow()
+    #            else:
+    #                return
+
+
+class Cursor(BaseCursor):
     def __init__(self, graph, root=None, stack=None):
         self.graph = graph
         self.vtype = graph.vtype
         self.root = root if root is not None else graph.default_root()
         if stack:
             self.stack = stack
-            self.current = self.stack[-1]
         else:
             self.reset()
 
+    def _current_attr(self, name):
+        stack = self.stack
+        if not stack:
+            raise InactiveCursor
+        return getattr(stack[-1], name)
+
+    def is_active(self):
+        return bool(self.stack)
+
+    def stopped(self):
+        return self._current_attr("target") is None
+
+    def accept(self):
+        return self._current_attr("accept")
+
+    def at_last_arc(self):
+        return self._current_attr("lastarc")
+
+    def label(self):
+        return self._current_attr("label")
+
     def reset(self):
         self.stack = []
         self.sums = [None]
     def copy(self):
         return self.__class__(self.graph, self.root, copy.deepcopy(self.stack))
 
-    def label(self):
-        return self.current.label
+    def prefix(self):
+        stack = self.stack
+        if not stack:
+            raise InactiveCursor
+        return (arc.label for arc in stack)
 
-    def prefix(self):
-        return (arc.label for arc in self.stack)
+    # Override: more efficient implementation using graph methods directly
+    def peek_key(self):
+        if not self.stack:
+            raise InactiveCursor
 
-    def prefix_bytes(self):
-        return emptybytes.join(arc.label for arc in self.stack)
-
-    def peek_key(self):
         for label in self.prefix():
             yield label
-        arc = copy.copy(self.current)
+        arc = copy.copy(self.stack[-1])
         graph = self.graph
         while not arc.accept and arc.target is not None:
             graph.arc_at(arc.target, arc)
             yield arc.label
 
-    def peek_key_bytes(self):
-        return emptybytes.join(self.peek_key())
-
-    def target(self):
-        return self.current.target
-
     def value(self):
+        stack = self.stack
+        if not stack:
+            raise InactiveCursor
         vtype = self.vtype
         if not vtype:
             raise Exception("No value type")
+
         v = self.sums[-1]
-        current = self.current
+        current = stack[-1]
         if current.value:
             v = vtype.add(v, current.value)
         if current.accept and current.acceptval is not None:
             v = vtype.add(v, current.acceptval)
         return v
 
-    def accept(self):
-        return self.current.accept
+    def next_arc(self):
+        stack = self.stack
+        if not stack:
+            raise InactiveCursor
 
-    def lastarc(self):
-        return self.current.lastarc
+        while stack and stack[-1].lastarc:
+            self.pop()
+        if stack:
+            current = stack[-1]
+            self.graph.arc_at(current.endpos, current)
+            return current
 
-    def can_follow(self):
-        return self.current.target is not None
+    def follow(self):
+        address = self._current_attr("target")
+        if address is None:
+            raise Exception("Can't follow a stop arc")
+        self._push(self.graph.arc_at(address))
+        return self
+
+    # Override: more efficient implementation manipulating the stack
+    def skip_to(self, key):
+        stack = self.stack
+        if not stack:
+            raise InactiveCursor
+
+        _follow = self.follow
+        _next_arc = self.next_arc
+
+        i = self._pop_to_prefix(key)
+        while stack and i < len(key):
+            curlabel = stack[-1].label
+            keylabel = key[i]
+            if curlabel == keylabel:
+                _follow()
+                i += 1
+            elif curlabel > keylabel:
+                return
+            else:
+                _next_arc()
+
+    # Override: more efficient implementation using find_arc
+    def switch_to(self, label):
+        stack = self.stack
+        if not stack:
+            raise InactiveCursor
+
+        current = stack[-1]
+        if label == current.label:
+            return True
+        else:
+            arc = self.graph.find_arc(current.endpos, label, current)
+            return arc
 
     def _push(self, arc):
         if self.vtype and self.stack:
             sums = self.sums
             sums.append(self.vtype.add(sums[-1], self.stack[-1].value))
         self.stack.append(arc)
-        self.current = arc
 
-    def _pop(self):
-        stack = self.stack
-        stack.pop()
+    def pop(self):
+        self.stack.pop()
         if self.vtype:
             self.sums.pop()
-        if stack:
-            self.current = stack[-1]
-        else:
-            self.current = None
-            raise EndOfCursor
 
-    def pop_to_prefix(self, key):
+    def _pop_to_prefix(self, key):
         stack = self.stack
+        if not stack:
+            raise InactiveCursor
+
         i = 0
         maxpre = min(len(stack), len(key))
         while i < maxpre and key[i] == stack[i].label:
             i += 1
         if stack[i].label > key[i]:
             self.current = None
-            raise EndOfCursor
+            return
         while len(stack) > i + 1:
-            self._pop()
+            self.pop()
         self.next_arc()
         return i
 
-    def skip_to(self, key):
-        i = self.pop_to_prefix(key)
-        while i < len(key):
-            curlabel = self.current.label
-            keylabel = key[i]
-            if curlabel == keylabel:
-                self.follow()
-                i += 1
-            elif curlabel > keylabel:
-                return
-            else:
-                self.next_arc()
 
-    def find_path(self, path):
-        _switch_to = self._switch_to
-        follow = self.follow
-
-        first = True
-        for i, label in enumerate(path):
-            if not first:
-                follow()
-            arc = _switch_to(label)
-            if arc is None:
-                return False
-            if arc.target is None:
-                if i < len(path) - 1:
-                    return False
-            first = False
-        return True
-
-    def follow(self):
-        address = self.current.target
-        if address is None:
-            raise Exception("Can't follow a stop arc")
-        self._push(self.graph.arc_at(address))
-
-    def _switch_to(self, label):
-        current = self.current
-        if label == current.label:
-            return current
-        else:
-            arc = self.graph.find_arc(current.endpos, label, current)
-            return arc
-
-    def follow_label(self, label):
-        arc = self._switch_to(label)
-        if arc:
-            self._push(arc)
-        return arc
-
-    def next_arc(self):
-        stack = self.stack
-        while stack and stack[-1].lastarc:
-            self._pop()
-        self.current = stack[-1]
-        self.graph.arc_at(self.current.endpos, self.current)
-
-    def flatten(self):
-        follow = self.follow
-        next_arc = self.next_arc
-        prefix_bytes = self.prefix_bytes
-
-        try:
-            while True:
-                current = self.current
-                if current.accept:
-                    yield prefix_bytes()
-                if current.target:
-                    follow()
-                    continue
-                next_arc()
-        except EndOfCursor:
-            return
-
-    def flatten_v(self):
-        for key in self.flatten():
-            yield key, self.value()
-
-    def follow_firsts(self):
-        while self.current.target is not None:
-            self.follow()
-
-    def follow_last(self):
-        while True:
-            while not self.current.lastarc:
-                self.next_arc()
-            if self.current.target is not None:
-                self.follow()
-            else:
-                return
+#class IntersectionCursor(BaseCursor):
+#    def __init__(self, a, b):
+#        self.a = a
+#        self.b = b
+#        self._active = self.a.is_active() and self.b.is_active() and self._sync()
+#
+#    def copy(self):
+#        return self.__class__(self.a.copy(), self.b.copy())
+#
+#    def _match_labels(self, a, b):
+#        while True:
+#            alab = a.label()
+#            blab = b.label()
+#            if alab == blab:
+#                return True
+#            elif a.at_last_arc() or b.at_last_arc():
+#                return False
+#            elif alab < blab:
+#                a.switch_to(blab)
+#            elif blab < alab:
+#                b.switch_to(alab)
+#
+#    def _sync(self):
+#        a = self.a
+#        b = self.b
+#        while True:
+#            if not self._match_labels(a, b):
+#                return False
+#            ac = a.copy()
+#            bc = b.copy()
+#            if self._match_labels(ac, bc):
+#                return True
+#
+#            if a.at_last_arc() or b.at_last_arc():
+#                return False
+#            a.next_arc()
+#            b.next_arc()
+#
+#    def is_active(self):
+#        return self._active
+#
+#    def label(self):
+#        if not self._active:
+#            raise InactiveCursor
+#        a = self.a.label()
+#        b = self.b.label()
+#        assert a == b
+#        return a
+#
+#    def stopped(self):
+#        if not self._active:
+#            raise InactiveCursor
+#        return self.a.stopped() or self.b.stopped()
+#
+#    def accept(self):
+#        if not self._active:
+#            raise InactiveCursor
+#        return self.a.accept() and self.b.accept()
+#
+#    def prefix(self):
+#        for alab, blab in izip(self.a.prefix(), self.b.prefix()):
+#            assert alab == blab
+#            yield alab
+#
+#    def at_last_arc(self):
+#        return self.a.at_last_arc() or self.b.at_last_arc()
+#
+#    def pop(self):
+#        self.a.pop()
+#        self.b.pop()
+#        if not (self.a.is_active() and self.b.is_active()):
+#            self._active = False
+#
+#    def next_arc(self):
+#        if not self._active:
+#            raise InactiveCursor
+#
+#        synced = False
+#        while not synced:
+#            if not (self.a.is_active() and self.b.is_active()):
+#                self._active = False
+#                return
+#            if self.a.at_last_arc() or self.b.at_last_arc():
+#                self.pop()
+#            self.a.next_arc()
+#            self.b.next_arc()
+#            synced = self._sync()
+#
+#    def follow(self):
+#        self.a.follow()
+#        self.b.follow()
+#        self._sync()
+#        return self
 
 
 # Graph reader
         return dict((arc.label, copy.copy(arc))
                     for arc in self.iter_arcs(address))
 
+    def find_path(self, path, arc=None):
+        if arc:
+            address = arc.target
+        else:
+            arc = Arc()
+            address = self._root
+
+        for label in path:
+            if address is None:
+                return None
+            if not self.find_arc(address, label, arc):
+                return None
+            address = arc.target
+        return arc
+
 
 class GraphReader(BaseGraphReader):
     def __init__(self, dbfile, rootname=None, labelsize=1, vtype=None,
 # Within edit distance function
 
 def within(graph, text, k=1, prefix=0, address=None):
+    """Yields a series of keys in the given graph within ``k`` edit distance of
+    ``text``. If ``prefix`` is greater than 0, all keys must match the first
+    ``prefix`` characters of ``text``.
+    """
+
     if address is None:
         address = graph._root
 
     accept = False
     if prefix:
         sofar = text[:prefix]
-        # This function duplicates a lot of arc-following functionality from
-        # Cursor, but here we have to instantiate a Cursor just to use its
-        # find_path method.
-        # TODO: find a better way
-        cur = Cursor(graph, address)
-        if not cur.find_path(sofar):
+        arc = graph.find_path(sofar)
+        if arc is None:
             return
-        address, accept = cur.target(), cur.accept()
+        address = arc.target
+        accept = arc.accept
 
     stack = [(address, k, prefix, sofar, accept)]
     seen = set()

File tests/test_dawg.py

     cur = dawg.Cursor(greader(st))
     assert_equal(list(cur.flatten()), ["alfa", "bravo", "charlie"])
 
+def test_inactive_raise():
+    st = gwrite(enlist("alfa bravo charlie"))
+    cur = dawg.Cursor(greader(st))
+    while cur.is_active():
+        cur.next_arc()
+    assert_raises(dawg.InactiveCursor, cur.label)
+    assert_raises(dawg.InactiveCursor, cur.prefix)
+    assert_raises(dawg.InactiveCursor, cur.prefix_bytes)
+    assert_raises(dawg.InactiveCursor, list, cur.peek_key())
+    assert_raises(dawg.InactiveCursor, cur.peek_key_bytes)
+    assert_raises(dawg.InactiveCursor, cur.stopped)
+    assert_raises(dawg.InactiveCursor, cur.value)
+    assert_raises(dawg.InactiveCursor, cur.accept)
+    assert_raises(dawg.InactiveCursor, cur.at_last_arc)
+    assert_raises(dawg.InactiveCursor, cur.next_arc)
+    assert_raises(dawg.InactiveCursor, cur.follow)
+    assert_raises(dawg.InactiveCursor, cur.switch_to, b("a"))
+    assert_raises(dawg.InactiveCursor, cur.skip_to, b("a"))
+    assert_raises(dawg.InactiveCursor, list, cur.flatten())
+    assert_raises(dawg.InactiveCursor, list, cur.flatten_v())
+    assert_raises(dawg.InactiveCursor, cur.find_path, b("a"))
+    assert_raises(dawg.InactiveCursor, cur.follow_firsts)
+
 def test_types():
     st = RamStorage()
 
 
     cur1.find_path(b("blo"))
     cur2.find_path(b("glo"))
-    assert_equal(cur1.current.target, cur2.current.target)
+    assert_equal(cur1.stack[-1].target, cur2.stack[-1].target)
 
 def test_fields():
     with TempStorage() as st:
     cur.follow_firsts()
     assert_equal(cur.prefix_bytes(), b("abcd"))
     assert cur.accept()
-    cur.pop_to_prefix("abzz")
+    cur._pop_to_prefix("abzz")
     assert_equal(cur.prefix_bytes(), b("abf"))
 
     cur = gr.cursor()
 
     cur = gr.cursor()
     cur.follow_firsts()
-    assert_raises(dawg.EndOfCursor, cur.skip_to, b("z"))
+    cur.skip_to(b("z"))
+    assert not cur.is_active()
 
-    cur = gr.cursor()
-    cur.follow_lasts()
-    assert_equal(cur.prefix_bytes(), b("wxyz"))