Commits

Matt Chaput committed 4eddae8

dawg.within() didn't pass an address to find_path(), if there was no default root, it would always abort. Fixes issue #261.
Fixed unbalanced start/finish_field calls to GraphWriter.

Comments (0)

Files changed (5)

src/whoosh/codec/whoosh2.py

         self.inlinelimit = inlinelimit
         self.block = None
         self.terminfo = None
+        self._infield = False
 
     def _make_dawg_files(self):
         dawgfile = self.segment.create_file(self.storage, W2Codec.DAWG_EXT)
         self.field = fieldobj
         self.format = fieldobj.format
         self.spelling = fieldobj.spelling and not fieldobj.separate_spelling()
+        self._dawgfield = False
         if self.spelling or fieldobj.separate_spelling():
             if self.dawg is None:
                 self._make_dawg_files()
             self.dawg.start_field(fieldname)
+            self._dawgfield = True
+        self._infield = True
 
     def start_term(self, text):
         if self.block is not None:
         self.termsindex.add((self.fieldname, self.text), terminfo)
 
     def finish_field(self):
-        if self.dawg:
+        if not self._infield:
+            raise Exception("Called finish_field before start_field")
+        self._infield = False
+
+        if self._dawgfield:
             self.dawg.finish_field()
+            self._dawgfield = False
 
     def close(self):
         self.termsindex.close()

src/whoosh/spelling.py

         dbfile = StructFile(dbfile)
 
     gw = dawg.GraphWriter(dbfile)
+    gw.start_field(fieldname)
     for word in wordlist:
         if strip:
             word = word.strip()
         gw.insert(word)
+    gw.finish_field()
     gw.close()
 
 

src/whoosh/support/dawg.py

 
 import sys, copy
 from array import array
-from hashlib import sha1  #@UnresolvedImport
+from hashlib import sha1  # @UnresolvedImport
 
 from whoosh.compat import (b, u, BytesIO, xrange, iteritems, iterkeys,
                            bytes_type, text_type, izip, array_tobytes)
         return i
 
 
-#class IntersectionCursor(BaseCursor):
-#    def __init__(self, a, b):
-#        self.a = a
-#        self.b = b
-#        self._active = self.a.is_active() and self.b.is_active() and self._sync()
-#
-#    def copy(self):
-#        return self.__class__(self.a.copy(), self.b.copy())
-#
-#    def _match_labels(self, a, b):
-#        while True:
-#            alab = a.label()
-#            blab = b.label()
-#            if alab == blab:
-#                return True
-#            elif a.at_last_arc() or b.at_last_arc():
-#                return False
-#            elif alab < blab:
-#                a.switch_to(blab)
-#            elif blab < alab:
-#                b.switch_to(alab)
-#
-#    def _sync(self):
-#        a = self.a
-#        b = self.b
-#        while True:
-#            if not self._match_labels(a, b):
-#                return False
-#            ac = a.copy()
-#            bc = b.copy()
-#            if self._match_labels(ac, bc):
-#                return True
-#
-#            if a.at_last_arc() or b.at_last_arc():
-#                return False
-#            a.next_arc()
-#            b.next_arc()
-#
-#    def is_active(self):
-#        return self._active
-#
-#    def label(self):
-#        if not self._active:
-#            raise InactiveCursor
-#        a = self.a.label()
-#        b = self.b.label()
-#        assert a == b
-#        return a
-#
-#    def stopped(self):
-#        if not self._active:
-#            raise InactiveCursor
-#        return self.a.stopped() or self.b.stopped()
-#
-#    def accept(self):
-#        if not self._active:
-#            raise InactiveCursor
-#        return self.a.accept() and self.b.accept()
-#
-#    def prefix(self):
-#        for alab, blab in izip(self.a.prefix(), self.b.prefix()):
-#            assert alab == blab
-#            yield alab
-#
-#    def at_last_arc(self):
-#        return self.a.at_last_arc() or self.b.at_last_arc()
-#
-#    def pop(self):
-#        self.a.pop()
-#        self.b.pop()
-#        if not (self.a.is_active() and self.b.is_active()):
-#            self._active = False
-#
-#    def next_arc(self):
-#        if not self._active:
-#            raise InactiveCursor
-#
-#        synced = False
-#        while not synced:
-#            if not (self.a.is_active() and self.b.is_active()):
-#                self._active = False
-#                return
-#            if self.a.at_last_arc() or self.b.at_last_arc():
-#                self.pop()
-#            self.a.next_arc()
-#            self.b.next_arc()
-#            synced = self._sync()
-#
-#    def follow(self):
-#        self.a.follow()
-#        self.b.follow()
-#        self._sync()
-#        return self
-
-
 class UncompiledNode(object):
     # Represents an "in-memory" node used by the GraphWriter before it is
     # written to disk.
         dbfile.write_int(self.version)
         dbfile.write_uint(0)
 
-        self.fieldname = None
-        self.start_field("_")
+        self._infield = False
 
     def start_field(self, fieldname):
         """Starts a new graph for the given field.
 
         if not fieldname:
             raise ValueError("Field name cannot be equivalent to False")
-        if self.fieldname is not None:
+        if self._infield:
             self.finish_field()
         self.fieldname = fieldname
         self.seen = {}
         self.nodes = [UncompiledNode(self)]
         self.lastkey = ''
         self._inserted = False
+        self._infield = True
 
     def finish_field(self):
         """Finishes the graph for the current field.
         """
 
+        if not self._infield:
+            raise Exception("Called finish_field before start_field")
+        self._infield = False
         if self._inserted:
             self.fieldroots[self.fieldname] = self._finish()
         self.fieldname = None
             a value here will raise an error.
         """
 
-        if self.fieldname is None:
+        if not self._infield:
             raise Exception("Inserted %r before starting a field" % key)
         self._inserted = True
         key = to_labels(key)  # Python 3 sucks
         return dict((arc.label, copy.copy(arc))
                     for arc in self.iter_arcs(address))
 
-    def find_path(self, path, arc=None):
+    def find_path(self, path, arc=None, address=None):
         path = to_labels(path)
 
         if arc:
             address = arc.target
         else:
             arc = Arc()
+
+        if address is None:
             address = self._root
 
         for label in path:
     accept = False
     if prefix:
         prefixchars = text[:prefix]
-        arc = graph.find_path(prefixchars)
+        arc = graph.find_path(prefixchars, address=address)
         if arc is None:
             return
         sofar = emptybytes.join(prefixchars)
         else:
             out.write(" " * 6)
         out.write("  " * tab)
-        out.write("%r %r %s %r\n" % (arc.label, arc.target, arc.accept, arc.value))
+        out.write("%r %r %s %r\n"
+                  % (arc.label, arc.target, arc.accept, arc.value))
         if arc.target is not None:
             dump_graph(graph, arc.target, tab + 1, out=out)
 

tests/test_dawg.py

     st = st or RamStorage()
     f = st.create_file("test")
     gw = dawg.GraphWriter(f)
+    gw.start_field("_")
     for key in keys:
         gw.insert(key)
+    gw.finish_field()
     gw.close()
     return st
 
 
 def test_empty_key():
     gw = dawg.GraphWriter(RamStorage().create_file("test"))
+    gw.start_field("_")
     assert_raises(KeyError, gw.insert, b(""))
     assert_raises(KeyError, gw.insert, "")
     assert_raises(KeyError, gw.insert, u(""))
 def test_keys_out_of_order():
     f = RamStorage().create_file("test")
     gw = dawg.GraphWriter(f)
+    gw.start_field("test")
     gw.insert("alfa")
     assert_raises(KeyError, gw.insert, "abba")
 
     with TempStorage() as st:
         f = st.create_file("test")
         gw = dawg.GraphWriter(f, vtype=t)
+        gw.start_field("_")
         for key, value in domain:
             gw.insert(key, value)
+        gw.finish_field()
         gw.close()
 
         f = st.open_file("test")
 
     st = RamStorage()
     gw = dawg.GraphWriter(st.create_file("test"))
+    gw.start_field("test")
     for key in domain:
         gw.insert(key)
     gw.close()
 
     st = RamStorage()
     gw = dawg.GraphWriter(st.create_file("test"))
+    gw.start_field("test")
     for key in domain:
         gw.insert(key)
     gw.close()
 
     st = RamStorage()
     gw = dawg.GraphWriter(st.create_file("test"))
+    gw.start_field("test")
     for key in domain:
         gw.insert(key)
     gw.close()

tests/test_spelling.py

 from whoosh.filedb.filestore import RamStorage
 from whoosh.qparser import QueryParser
 from whoosh.support import dawg
-from whoosh.support.testing import TempStorage
+from whoosh.support.testing import TempStorage, TempIndex
 
 
 def words_to_corrector(words):
     assert_not_equal(gc.suggest("bone")[0], "bone")
 
 
+def test_suggest_prefix():
+    domain = ("Shoot To Kill",
+              "Bloom, Split and Deviate",
+              "Rankle the Seas and the Skies",
+              "Lightning Flash Flame Shell",
+              "Flower Wind Rage and Flower God Roar, Heavenly Wind Rage and "
+              "Heavenly Demon Sneer",
+              "All Waves, Rise now and Become my Shield, Lightning, Strike "
+              "now and Become my Blade",
+              "Cry, Raise Your Head, Rain Without end",
+              "Sting All Enemies To Death",
+              "Reduce All Creation to Ash",
+              "Sit Upon the Frozen Heavens",
+              "Call forth the Twilight")
 
+    schema = fields.Schema(content=fields.TEXT(stored=True, spelling=True),
+                           quick=fields.NGRAM(maxsize=10, stored=True))
 
+    with TempIndex(schema, "sugprefix") as ix:
+        with ix.writer() as w:
+            for item in domain:
+                content = u(item)
+                w.add_document(content=content, quick=content)
 
+        with ix.searcher() as s:
+            sugs = s.suggest("content", u("ra"), maxdist=2, prefix=2)
+            assert_equal(sugs, ['rage', 'rain'])
 
+            sugs = s.suggest("content", "ra", maxdist=2, prefix=1)
+            assert_equal(sugs, ["rage", "rain", "roar"])
 
+
+def test_prefix_address():
+    fieldtype = fields.TEXT(spelling=True)
+    schema = fields.Schema(f1=fieldtype, f2=fieldtype)
+    with TempIndex(schema, "prefixaddr") as ix:
+        with ix.writer() as w:
+            w.add_document(f1=u("aabc aawx aaqr aade"),
+                           f2=u("aa12 aa34 aa56 aa78"))
+
+        with ix.searcher() as s:
+            sugs = s.suggest("f1", u("aa"), maxdist=2, prefix=2)
+            assert_equal(sorted(sugs), ["aabc", "aade", "aaqr", "aawx"])
+
+            sugs = s.suggest("f2", u("aa"), maxdist=2, prefix=2)
+            assert_equal(sorted(sugs), ["aa12", "aa34", "aa56", "aa78"])
+
+
+
+