Commits

Matt Chaput committed 1184038

More work on spelling.

  • Parent commits b8fc9b0
  • Branches dawg


Files changed (7)

File src/whoosh/filedb/filereading.py

         self.dawg = None
         if any(field.spelling for field in self.schema):
             fname = segment.dawg_filename
-            if not self.storage.file_exists(fname):
-                spelled = [fn for fn, field in self.schema.items() if field.spelling]
-                raise Exception("Field(s) %r have spelling=True but DAWG file %r not found" % (spelled, fname))
-            
-            dawgfile = self.storage.open_file(fname, mapped=False)
-            self.dawg = DawgReader(dawgfile).root
+            if self.storage.file_exists(fname):
+                dawgfile = self.storage.open_file(fname, mapped=False)
+                self.dawg = DawgReader(dawgfile, expand=False).root
         
         self.dc = segment.doc_count_all()
         assert self.dc == self.storedfields.length

File src/whoosh/filedb/filewriting.py

         dawg = None
         if any(field.spelling for field in self.schema):
             df = self.storage.create_file(segment.dawg_filename)
-            dawg = DawgWriter(df)
+            dawg = DawgWriter(df, reduce_root=False)
         
         # Terms index
         tf = self.storage.create_file(segment.termsindex_filename)
         self.postwriter.close()
         if self.dawg:
             self.dawg.write()
-        
 
-def add_spelling(ix, fieldnames, force=False):
+
+def add_spelling(ix, fieldnames, commit=True):
     """Adds spelling files to an existing index that was created without
     them, and modifies the schema so the given fields have the ``spelling``
     attribute. Only works on filedb indexes.
     schema = writer.schema
     segments = writer.segments
     
-    for fieldname in fieldnames:
-        schema[fieldname].spelling = True
-    
     for segment in segments:
         filename = segment.dawg_filename
-        if storage.file_exists(filename) and not force:
-            continue
-        
         r = SegmentReader(storage, schema, segment)
         f = storage.create_file(filename)
-        dw = DawgWriter(f)
+        dw = DawgWriter(reduce_root=False)
         for fieldname in fieldnames:
             ft = (fieldname, )
             for word in r.lexicon(fieldname):
                 dw.insert(ft + tuple(word))
-        dw.close()
-    writer.commit()
+        dw.write(f)
+    
+    for fieldname in fieldnames:
+        schema[fieldname].spelling = True
+    
+    if commit:
+        writer.commit(merge=False)
 
-
+    
         
 
 

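The new add_spelling function retrofits word graphs onto an index whose
fields were originally indexed without spelling=True. A minimal usage
sketch, assuming a filedb index on disk (the directory and field names
are placeholders, not from the commit):

    from whoosh import index
    from whoosh.filedb.filewriting import add_spelling

    # Open an existing index, generate a DAWG file for each segment,
    # then mark the fields as having spelling support in the schema.
    ix = index.open_dir("indexdir")
    add_spelling(ix, ["title", "content"])
    # Passing commit=False skips the final writer.commit(merge=False),
    # leaving the schema change uncommitted.
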
File src/whoosh/reading.py

             for word in within(node, text, maxdist, prefix=prefix, seen=seen):
                 yield word
         else:
+            if seen is None:
+                seen = set()
             for word in self.expand_prefix(fieldname, text[:prefix]):
-                if word == text:
-                    yield text
-                elif distance(word, text, limit=maxdist) <= maxdist:
+                if word in seen:
+                    continue
+                if word == text or distance(word, text, limit=maxdist) <= maxdist:
                     yield word
+                    seen.add(word)
     
     def most_frequent_terms(self, fieldname, number=5, prefix=''):
         """Returns the top 'number' most frequent terms in the given field as a
         segmentnum, segmentdoc = self._segment_and_docnum(docnum)
         return self.readers[segmentnum].vector_as(astype, segmentdoc, fieldname)
 
+    def has_word_graph(self, fieldname):
+        return any(r.has_word_graph(fieldname) for r in self.readers)
+    
+    def word_graph(self, fieldname):
+        from whoosh.support.dawg import NullNode, UnionNode
+        from whoosh.util import make_binary_tree
+        
+        graphs = [r.word_graph(fieldname) for r in self.readers
+                  if r.has_word_graph(fieldname)]
+        if not graphs:
+            return NullNode()
+        if len(graphs) == 1:
+            return graphs[0]
+        return make_binary_tree(UnionNode, graphs)
+
     def format(self, fieldname):
         for r in self.readers:
             fmt = r.format(fieldname)

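MultiReader.word_graph merges the per-segment graphs pairwise rather
than chaining them, so lookups walk a balanced tree of UnionNode
wrappers. A stand-in sketch of the assumed make_binary_tree behavior
(the real helper lives in whoosh.util):

    # Illustrative reimplementation, not the library code: combine a
    # list of items pairwise until a single root remains.
    def make_binary_tree(fn, args):
        if len(args) == 1:
            return args[0]
        mid = len(args) // 2
        return fn(make_binary_tree(fn, args[:mid]),
                  make_binary_tree(fn, args[mid:]))

    # With graphs g1..g4 this yields fn(fn(g1, g2), fn(g3, g4)),
    # a balanced union tree instead of a left-deep chain.
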
File src/whoosh/searching.py

 
 from whoosh import classify, highlight, query, scoring
 from whoosh.reading import TermNotFound
-from whoosh.spelling import suggest
 from whoosh.support.bitvector import BitSet, BitVector
 from whoosh.util import now, lru_cache
 

File src/whoosh/spelling.py

     """Ranks suggestions by the edit distance.
     """
     
-    return cost
+    return (cost, 0)
 
 
 class Corrector(object):
     for the field and use :func:`suggest` instead of this object.
     """
     
-    def score(self, word, cost):
-        """Returns a rank value (where lower values are better suggestions)
-        for the given word and "cost" (usually edit distance).
-        """
-        
-        return cost
-        
     def suggest(self, text, limit=5, maxdist=2, prefix=0):
         """
         :param text: the text to check.
             list of words.
         """
         
-        score = self.score
         suggestions = self.suggestions
         
         heap = []
         seen = set()
         for k in xrange(1, maxdist+1):
-            for sug in suggestions(text, k, prefix, seen):
-                item = (score(sug, k), sug)
+            for item in suggestions(text, k, prefix, seen):
                 if len(heap) < limit:
                     heappush(heap, item)
                 elif item < heap[0]:
         return [sug for _, sug in sorted(heap)]
         
     def suggestions(self, text, maxdist, prefix, seen):
-        """Low-level method that yields a series of ("suggestion", cost) tuples.
+        """Low-level method that yields a series of (score, "suggestion")
+        tuples.
         
         :param text: the text to check.
         :param maxdist: the maximum edit distance.
         self.reader = reader
         self.fieldname = fieldname
     
-    def score(self, word, cost):
-        return (cost, 0 - self.reader.frequency(self.fieldname, word))
+    def suggestions(self, text, maxdist, prefix, seen):
+        fieldname = self.fieldname
+        freq = self.reader.frequency
+        for sug in self.reader.terms_within(self.fieldname, text, maxdist,
+                                            prefix=prefix, seen=seen):
+            yield ((maxdist, 0 - freq(fieldname, sug)), sug)
+
+
+class GraphCorrector(Corrector):
+    """Suggests corrections based on the content of a word list.
+    
+    By default, ranks suggestions by edit distance.
+    """
+
+    def __init__(self, word_graph, ranking=None):
+        self.word_graph = word_graph
+        self.ranking = ranking or simple_scorer
     
     def suggestions(self, text, maxdist, prefix, seen):
-        return self.reader.terms_within(self.fieldname, text, maxdist,
-                                        prefix=prefix, seen=seen)
+        ranking = self.ranking
+        for sug in dawg.within(self.word_graph, text, maxdist, prefix=prefix,
+                               seen=seen):
+            yield (ranking(sug, maxdist), sug)
+    
+    def save(self, filename):
+        f = open(filename, "wb")
+        self.word_graph.write(f)
+        f.close()
+    
+    @classmethod
+    def from_word_list(cls, wordlist, ranking=None, fieldname=""):
+        dw = dawg.DawgWriter()
+        for word in wordlist:
+            dw.insert(word)
+        return cls(dw.root, ranking=ranking)
+    
+    @classmethod
+    def from_graph_file(cls, dbfile, ranking=None, fieldname=""):
+        dr = dawg.DawgReader(dbfile)
+        return cls(dr.root, ranking=ranking)
+    
+
+class MultiCorrector(Corrector):
+    """Merges suggestions from a list of sub-correctors.
+    """
+    
+    def __init__(self, correctors):
+        self.correctors = correctors
         
+    def suggestions(self, text, maxdist, prefix, seen):
+        for corr in self.correctors:
+            for item in corr.suggestions(text, maxdist, prefix, seen):
+                yield item
+
 
 def wordlist_to_graph_file(wordlist, dbfile):
     """Writes a word graph file from a list of words.
     dbfile.close()
 
 
-class GraphCorrector(Corrector):
-    """Suggests corrections based on the content of a word list.
-    
-    By default ranks suggestions based on the edit distance.
-    """
-
-    def __init__(self, word_graph, ranking=None):
-        self.word_graph = word_graph
-        self.score = ranking or simple_scorer
-    
-    def suggestions(self, text, maxdist, prefix, seen):
-        return dawg.within(self.word_graph, text, maxdist, prefix=prefix,
-                           seen=seen)
-    
-    def save(self, filename):
-        f = open(filename, "wb")
-        self.word_graph.write(f)
-        f.close()
-    
-    @classmethod
-    def from_word_list(cls, wordlist, ranking=None, fieldname=""):
-        dw = dawg.DawgWriter()
-        for word in wordlist:
-            dw.add(fieldname, word)
-        return cls(dw.root, ranking=ranking)
-    
-    @classmethod
-    def from_graph_file(cls, dbfile, ranking=None, fieldname=""):
-        dr = dawg.DawgReader(dbfile)
-        return cls(dr.field_root(fieldname), ranking=ranking)
-    
-
-# Old, obsolete spell checker
+# Old, obsolete spell checker - DO NOT USE
 
 class SpellChecker(object):
     """This feature is obsolete.

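A quick sketch of the new GraphCorrector API. Note that from_word_list
feeds words to DawgWriter.insert, which requires sorted input:

    from whoosh import spelling

    # Build an in-memory word graph and ask for suggestions, ranked by
    # edit distance (the default simple_scorer ranking).
    words = sorted(["preaction", "reaction", "reduction", "render"])
    corrector = spelling.GraphCorrector.from_word_list(words)
    print corrector.suggest("reoction", maxdist=2)

Correctors built this way can also be wrapped in a MultiCorrector to
merge their suggestion streams with those of other correctors.
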
File src/whoosh/support/dawg.py

     def __len__(self):
         raise NotImplementedError
     
-    def edge(self, key):
+    def keys(self):
+        return list(self)
+    
+    def edge(self, key, expand=True):
         raise NotImplementedError
     
     def all_edges(self):
         return len(self) + sum(self.edge(key).edge_count() for key in self)
 
 
+class NullNode(BaseNode):
+    """An empty node. This is sometimes useful for representing an empty graph.
+    """
+    
+    final = False
+    
+    def __contains__(self, key):
+        return False
+    
+    def __iter__(self):
+        return iter([])
+    
+    def __len__(self):
+        return 0
+    
+    def edge(self, key, expand=True):
+        raise KeyError(key)
+    
+    def all_edges(self):
+        return {}
+    
+    def edge_count(self):
+        return 0
+
+
 class BuildNode(object):
     def __init__(self):
         self.final = False
         self._hash = None  # Invalidate the cached hash value
         self._edges[key] = node
     
-    def edge(self, key):
+    def edge(self, key, expand=True):
         return self._edges[key]
     
     def all_edges(self):
     
 
 class DawgWriter(object):
-    def __init__(self, dbfile=None, reduced=True):
+    def __init__(self, dbfile=None, reduced=True, reduce_root=True):
         self.dbfile = dbfile
         self.reduced = reduced
+        self.reduce_root = reduce_root
         
         self.lastword = ""
         # List of nodes that have not been checked for duplication.
 
     def write(self, dbfile=None):
         dbfile = self.dbfile or dbfile
+        dbfile.write("GR01")  # Magic number
         dbfile.write_int(0)  # File flags
         dbfile.write_uint(0)  # Pointer to root node
         
         self._minimize(0)
         root = self.root
         if self.reduced:
-            reduce(root)
-        offset = self._write_node(root)
-        self._reset()
+            if self.reduce_root:
+                reduce(root)
+            else:
+                for key in root:
+                    v = root.edge(key)
+                    reduce(v)
+        offset = self._write_node(dbfile, root)
         
+        # Seek back and write the pointer to the root node
         dbfile.flush()
-        dbfile.seek(_INT_SIZE)
+        dbfile.seek(_INT_SIZE * 2)
         dbfile.write_uint(offset)
         dbfile.close()
     
-    def _write_node(self, node):
-        dbfile = self.dbfile
+    def _write_node(self, dbfile, node):
         keys = node._edges.keys()
         ptrs = array("I")
         for key in keys:
             if id(sn) in self.offsets:
                 ptrs.append(self.offsets[id(sn)])
             else:
-                ptr = self._write_node(sn)
+                ptr = self._write_node(dbfile, sn)
                 self.offsets[id(sn)] = ptr
                 ptrs.append(ptr)
         
 
 
 class DiskNode(BaseNode):
-    def __init__(self, dr, offset):
+    def __init__(self, dr, offset, expand=True):
         self.dr = dr
         self.id = offset
         
                     self._edges[unichr(charnum)] = ptr
                 else:
                     key = utf8decode(dbfile.read_string())[0]
-                    if len(key) > 1:
+                    if len(key) > 1 and expand:
                         self._edges[key[0]] = PatNode(dr, key[1:], ptr)
                     else:
                         self._edges[key] = ptr
-            
+    
     def __repr__(self):
-        return "<%s:%s %s>" % (self.id, "".join(self._edges.keys()), self.final)
+        return "<%s %s:%s %s>" % (self.__class__.__name__, self.id, ",".join(self._edges.keys()), self.final)
     
     def __contains__(self, key):
         return key in self._edges
     def __len__(self):
         return len(self._edges)
     
-    def edge(self, key):
+    def edge(self, key, expand=True):
         v = self._edges[key]
         if not isinstance(v, BaseNode):
             # Convert pointer to disk node
-            v = DiskNode(self.dr, v)
+            v = DiskNode(self.dr, v, expand=expand)
             #if self.caching:
             self._edges[key] = v
         return v
         else:
             return 0
     
-    def edge(self, key):
+    def edge(self, key, expand=True):
         label = self.label
         i = self.i
         if i < len(label) and key == label[i]:
     """Makes two graphs appear to be the union of the two graphs.
     """
     
-    def edge(self, key):
+    def edge(self, key, expand=True):
         a = self.a
         b = self.b
         if key in a and key in b:
     """Makes two graphs appear to be the intersection of the two graphs.
     """
     
-    def edge(self, key):
+    def edge(self, key, expand=True):
         a = self.a
         b = self.b
         if key in a and key in b:
 # Reader for disk-based graph files
 
 class DawgReader(object):
-    def __init__(self, dbfile):
+    def __init__(self, dbfile, expand=True):
         self.dbfile = dbfile
         
         dbfile.seek(0)
+        magic = dbfile.read(4)
+        assert magic == "GR01"
         self.fileflags = dbfile.read_int()
-        self.root = DiskNode(self, dbfile.read_uint())
-    
+        self.root = DiskNode(self, dbfile.read_uint(), expand=expand)
+
 
 # Functions
 
     return c + sum(edge_count(node.edge(key)) for key in node)
 
 
+def flatten(node, sofar=""):
+    if node.final:
+        yield sofar
+    for key in sorted(node):
+        for word in flatten(node.edge(key, expand=False), sofar + key):
+            yield word
+
+
 def dump_dawg(node, tab=0):
     print "  " * tab, id(node), node.final
     for key in node:
 
 
 
+    
 
 
 
 
 
 
+

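With the magic number added, the graph file header that DawgReader
checks is twelve bytes. A sketch of reading it back with the standard
struct module; the big-endian byte order is an assumption about
whoosh's structfile format, not confirmed by this commit:

    import struct

    # Assumed GR01 header layout written by DawgWriter.write():
    #   bytes 0-3   magic string "GR01"
    #   bytes 4-7   file flags (int)
    #   bytes 8-11  offset of the root node (uint), back-filled after
    #               the node data is written (hence seek(_INT_SIZE * 2))
    def read_graph_header(f):
        magic = f.read(4)
        assert magic == "GR01"
        flags, root_offset = struct.unpack("!iI", f.read(8))
        return flags, root_offset
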
File tests/test_spelling.py

+from __future__ import with_statement
 from nose.tools import assert_equal, assert_not_equal
 
-from whoosh import spelling
+from whoosh import fields, spelling
 from whoosh.filedb.filestore import RamStorage
 
 
-def test_spelling():
-    st = RamStorage()
+def test_graph_corrector():
+    wordlist = sorted(["render", "animation", "animate", "shader",
+                       "shading", "zebra", "koala", "lamppost",
+                       "ready", "kismet", "reaction", "page",
+                       "delete", "quick", "brown", "fox", "jumped",
+                       "over", "lazy", "dog", "wicked", "erase",
+                       "red", "team", "yellow", "under", "interest",
+                       "open", "print", "acrid", "sear", "deaf",
+                       "feed", "grow", "heal", "jolly", "kilt",
+                       "low", "zone", "xylophone", "crown",
+                       "vale", "brown", "neat", "meat", "reduction",
+                       "blunder", "preaction"])
     
-    sp = spelling.SpellChecker(st, mingram=2)
+    sp = spelling.GraphCorrector.from_word_list(wordlist)
+    sugs = sp.suggest("reoction", maxdist=2)
+    assert_equal(sugs, ["reaction", "preaction", "reduction"])
+
+def test_reader_corrector_nograph():
+    schema = fields.Schema(text=fields.TEXT)
+    ix = RamStorage().create_index(schema)
+    w = ix.writer()
+    w.add_document(text=u"render zorro kaori postal")
+    w.add_document(text=u"reader zebra koala pastry")
+    w.add_document(text=u"leader libra ooala paster")
+    w.add_document(text=u"feeder lorry zoala baster")
+    w.commit()
     
-    wordlist = ["render", "animation", "animate", "shader",
-                "shading", "zebra", "koala", "lamppost",
-                "ready", "kismet", "reaction", "page",
-                "delete", "quick", "brown", "fox", "jumped",
-                "over", "lazy", "dog", "wicked", "erase",
-                "red", "team", "yellow", "under", "interest",
-                "open", "print", "acrid", "sear", "deaf",
-                "feed", "grow", "heal", "jolly", "kilt",
-                "low", "zone", "xylophone", "crown",
-                "vale", "brown", "neat", "meat", "reduction",
-                "blunder", "preaction"]
+    with ix.reader() as r:
+        sp = spelling.ReaderCorrector(r, "text")
+        assert_equal(sp.suggest(u"kaola", maxdist=1), [u'koala'])
+        assert_equal(sp.suggest(u"kaola", maxdist=2), [u'koala', u'kaori', u'ooala', u'zoala'])
+
+def test_reader_corrector():
+    schema = fields.Schema(text=fields.TEXT(spelling=True))
+    ix = RamStorage().create_index(schema)
+    w = ix.writer()
+    w.add_document(text=u"render zorro kaori postal")
+    w.add_document(text=u"reader zebra koala pastry")
+    w.add_document(text=u"leader libra ooala paster")
+    w.add_document(text=u"feeder lorry zoala baster")
+    w.commit()
     
-    sp.add_words([unicode(w) for w in wordlist])
+    with ix.reader() as r:
+        assert r.has_word_graph("text")
+        sp = spelling.ReaderCorrector(r, "text")
+        assert_equal(sp.suggest(u"kaola", maxdist=1), [u'koala'])
+        assert_equal(sp.suggest(u"kaola", maxdist=2), [u'koala', u'kaori', u'ooala', u'zoala'])
+
+def test_add_spelling():
+    schema = fields.Schema(text1=fields.TEXT, text2=fields.TEXT)
+    ix = RamStorage().create_index(schema)
+    w = ix.writer()
+    w.add_document(text1=u"render zorro kaori postal", text2=u"alfa")
+    w.add_document(text1=u"reader zebra koala pastry", text2=u"alpa")
+    w.add_document(text1=u"leader libra ooala paster", text2=u"alpha")
+    w.add_document(text1=u"feeder lorry zoala baster", text2=u"olfo")
+    w.commit()
     
-    sugs = sp.suggest(u"reoction")
-    assert_not_equal(len(sugs), 0)
-    assert_equal(sugs, [u"reaction", u"reduction", u"preaction"])
+    with ix.reader() as r:
+        assert not r.has_word_graph("text1")
+        assert not r.has_word_graph("text2")
     
-def test_suggestionsandscores():
-    st = RamStorage()
-    sp = spelling.SpellChecker(st, mingram=2)
+    from whoosh.filedb.filewriting import add_spelling
+    add_spelling(ix, ["text1", "text2"])
     
-    words = [("alfa", 10), ("bravo", 9), ("charlie", 8), ("delta", 7),
-             ("echo", 6), ("foxtrot", 5), ("golf", 4), ("hotel", 3),
-             ("india", 2), ("juliet", 1)]
-    sp.add_scored_words((unicode(w), s) for w, s in words)
+    with ix.reader() as r:
+        assert r.has_word_graph("text1")
+        assert r.has_word_graph("text2")
+        
+        sp = spelling.ReaderCorrector(r, "text1")
+        assert_equal(sp.suggest(u"kaola", maxdist=1), [u'koala'])
+        assert_equal(sp.suggest(u"kaola", maxdist=2), [u'koala', u'kaori', u'ooala', u'zoala'])
+
+        sp = spelling.ReaderCorrector(r, "text2")
+        assert_equal(sp.suggest(u"alfo", maxdist=1), [u"alfa", u"olfo"])
+
+def test_multi():
+    from whoosh.support.dawg import flatten
     
-    from whoosh.scoring import Frequency
-    sugs = sp.suggestions_and_scores(u"alpha", weighting=Frequency())
-    assert_equal(sugs, [(u"alfa", 10, 3.0), (u"charlie", 8, 1.0)])
+    schema = fields.Schema(text=fields.TEXT(spelling=True))
+    ix = RamStorage().create_index(schema)
+    domain = u"special specious spectacular spongy spring specials".split()
+    for word in domain:
+        w = ix.writer()
+        w.add_document(text=word)
+        w.commit(merge=False)
+    
+    with ix.reader() as r:
+        assert not r.is_atomic()
+        words = list(flatten(r.word_graph("text")))
+        assert_equal(words, sorted(domain))
 
-def test_minscore():
-    st = RamStorage()
-    sp = spelling.SpellChecker(st, mingram=2, minscore=2.0)
-    
-    sp.add_words([u'charm', u'amour'])
-    
-    sugs = sp.suggest(u"armor")
-    assert_equal(sugs, [u'charm'])
+        corr = r.corrector("text")
+        assert_equal(corr.suggest("specail", maxdist=2), ["special", "specials"])
 
 
 
-