Commits

Matt Chaput committed cbf9d80

Integrated DAWG writing/reading and suggestions into backend.

  • Parent commits cf4a3af
  • Branches dawg


Files changed (11)

File benchmark/reuters.py

         ana = analysis.StandardAnalyzer()
         schema = fields.Schema(id=fields.ID(stored=True),
                                headline=fields.STORED,
-                               text=fields.TEXT(analyzer=ana, stored=True))
+                               text=fields.TEXT(analyzer=ana, stored=True,
+                                                spelling=True))
         return schema
     
     def zcatalog_setup(self, cat):
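
The spelling=True flag above feeds the field's words into the new word
graph at index time. A minimal sketch of building such an index (the
directory name and document values are hypothetical):

    from whoosh import index

    ix = index.create_in("reuters_ix", schema)  # hypothetical directory
    w = ix.writer()
    w.add_document(id=u"1", headline=u"Example", text=u"a special report")
    w.commit()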

File src/whoosh/fields.py

     multitoken_query = "first"
     sortable_type = unicode
     sortable_typecode = None
+    spelling = False
     
     __inittypes__ = dict(format=Format, vector=Format,
                          scorable=bool, stored=bool, unique=bool)
     
     __inittypes__ = dict(stored=bool, unique=bool, field_boost=float)
     
-    def __init__(self, stored=False, unique=False, field_boost=1.0):
+    def __init__(self, stored=False, unique=False, field_boost=1.0,
+                 spelling=False):
         """
         :param stored: Whether the value of this field is stored with the document.
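+        :param spelling: Whether the contents of this field should be used
+            for generating spelling suggestions.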
         """
         self.format = Existence(analyzer=IDAnalyzer(), field_boost=field_boost)
         self.stored = stored
         self.unique = unique
+        self.spelling = spelling
 
 
 class IDLIST(FieldType):
     
     __inittypes__ = dict(stored=bool, unique=bool, expression=bool, field_boost=float)
     
-    def __init__(self, stored=False, unique=False, expression=None, field_boost=1.0):
+    def __init__(self, stored=False, unique=False, expression=None,
+                 field_boost=1.0, spelling=False):
         """
         :param stored: Whether the value of this field is stored with the
             document.
         self.format = Existence(analyzer=analyzer, field_boost=field_boost)
         self.stored = stored
         self.unique = unique
+        self.spelling = spelling
 
 
 class NUMERIC(FieldType):
                          unique=bool, field_boost=float)
     
     def __init__(self, stored=False, lowercase=False, commas=False,
-                 scorable=False, unique=False, field_boost=1.0):
+                 scorable=False, unique=False, field_boost=1.0, spelling=True):
         """
         :param stored: Whether to store the value of the field with the
             document.
         self.scorable = scorable
         self.stored = stored
         self.unique = unique
+        self.spelling = spelling
 
 
 class TEXT(FieldType):
                          stored=bool, field_boost=float)
     
     def __init__(self, analyzer=None, phrase=True, vector=None, stored=False,
-                 field_boost=1.0, multitoken_query="first"):
+                 field_boost=1.0, multitoken_query="first", spelling=False):
         """
         :param analyzer: The analysis.Analyzer to use to index the field
             contents. See the analysis module for more information. If you omit
         self.multitoken_query = multitoken_query
         self.scorable = True
         self.stored = stored
+        self.spelling = spelling
 
 
 class NGRAM(FieldType):

File src/whoosh/filedb/fileindex.py

     along the way).
     """
 
-    EXTENSIONS = {"fieldlengths": "fln", "storedfields": "sto",
-                  "termsindex": "trm", "termposts": "pst",
-                  "vectorindex": "vec", "vectorposts": "vps"}
+    EXTENSIONS = {"dawg": "dag",
+                  "fieldlengths": "fln",
+                  "storedfields": "sto",
+                  "termsindex": "trm",
+                  "termposts": "pst",
+                  "vectorindex": "vec",
+                  "vectorposts": "vps"}
     
     generation = 0
     

File src/whoosh/filedb/filereading.py

                                       LengthReader, TermVectorReader)
 from whoosh.matching import FilterMatcher, ListMatcher
 from whoosh.reading import IndexReader, TermNotFound
+from whoosh.spelling import suggest
+from whoosh.support.dawg import DawgReader
 from whoosh.util import protected
 
 SAVE_BY_DEFAULT = True
         self.postfile = self.storage.open_file(segment.termposts_filename,
                                                mapped=False)
         
+        # Dawg file
+        self.dawg = None
+        if any(field.spelling for field in self.schema):
+            dawgfile = self.storage.open_file(segment.dawg_filename,
+                                              mapped=False)
+            self.dawg = DawgReader(dawgfile)
+        
         self.dc = segment.doc_count_all()
         assert self.dc == self.storedfields.length
         
         
         return FilePostingReader(self.vpostfile, offset, vformat, stringids=True)
 
+    # DAWG methods
+
+    def has_word_graph(self, fieldname):
+        if fieldname not in self.schema:
+            raise TermNotFound("No field %r" % fieldname)
+        if not self.schema[fieldname].spelling:
+            return False
+        if self.dawg:
+            return fieldname in self.dawg.fields
+        return False
+
+    def terms_within(self, fieldname, word, maxdist, prefix=0):
+        if not self.has_word_graph(fieldname):
+            raise Exception("No word graph for field %r" % fieldname)
+        
+        return self.dawg.within(fieldname, word, maxdist, prefix=prefix)
+    
+    def _field_root(self, fieldname):
+        if not self.has_word_graph(fieldname):
+            raise Exception("No word graph for field %r" % fieldname)
+        
+        return self.dawg.field_root(fieldname)
+    
     # Field cache methods
 
     def supports_caches(self):
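
A sketch of the new reader-side methods, assuming an open SegmentReader r
over a segment whose schema includes a spelling-enabled "text" field:

    if r.has_word_graph("text"):
        # Generator of indexed words within one edit of the given word
        for w in r.terms_within("text", u"specail", 1):
            print w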

File src/whoosh/filedb/filewriting.py

                                       TermVectorWriter)
 from whoosh.filedb.pools import TempfilePool
 from whoosh.store import LockError
+from whoosh.support.dawg import DawgWriter
 from whoosh.support.filelock import try_for
 from whoosh.util import fib
 from whoosh.writing import IndexWriter, IndexingError
         # Create a temporary segment to use its .*_filename attributes
         segment = Segment(self.name, self.generation, 0, None, None)
         
+        # DAWG file
+        dawg = None
+        if any(field.spelling for field in self.schema):
+            df = self.storage.create_file(segment.dawg_filename)
+            dawg = DawgWriter(df)
+        
         # Terms index
         tf = self.storage.create_file(segment.termsindex_filename)
         ti = TermIndexWriter(tf)
         pf = self.storage.create_file(segment.termposts_filename)
         pw = FilePostingWriter(pf, blocklimit=blocklimit)
         # Terms writer
-        self.termswriter = TermsWriter(self.schema, ti, pw)
+        self.termswriter = TermsWriter(self.schema, ti, pw, dawg)
         
         if self.schema.has_vectored_fields():
             # Vector index
 
 
 class TermsWriter(object):
-    def __init__(self, schema, termsindex, postwriter, inlinelimit=1):
+    def __init__(self, schema, termsindex, postwriter, dawg, inlinelimit=1):
         self.schema = schema
         self.termsindex = termsindex
         self.postwriter = postwriter
+        self.dawg = dawg
         self.inlinelimit = inlinelimit
         
+        self.hasdawg = set(fieldname for fieldname, field in self.schema.items()
+                           if field.spelling)
         self.lastfn = None
         self.lasttext = None
         self.format = None
         self.offset = None
-    
+        
     def _new_term(self, fieldname, text):
         lastfn = self.lastfn
         lasttext = self.lasttext
             raise Exception("Postings are out of order: %r:%s .. %r:%s" %
                             (lastfn, lasttext, fieldname, text))
     
+        if fieldname in self.hasdawg:
+            self.dawg.add(fieldname, text)
+    
         if fieldname != lastfn:
             self.format = self.schema[fieldname].format
-    
+        
         if fieldname != lastfn or text != lasttext:
             self._finish_term()
             # Reset the term attributes
         self._finish_term()
         self.termsindex.close()
         self.postwriter.close()
+        if self.dawg:
+            self.dawg.close()
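
For reference, the protocol TermsWriter relies on: words must be added to
the DawgWriter in sorted order, grouped by field, and close() finishes the
last field and writes the field index. A sketch (the storage object st and
the filename are hypothetical):

    f = st.create_file("_MAIN_1.dag")
    dawg = DawgWriter(f)
    dawg.add("text", u"apple")   # words within a field must arrive sorted
    dawg.add("text", u"apples")
    dawg.close()                 # pickles the field index, patches the pointer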
         
         
         

File src/whoosh/query.py

                     self.boost, self.maxdist, self.prefixlength)
 
     def __unicode__(self):
-        r = u"~" + self.text
+        r = self.text + "~"
+        if self.maxdist > 1:
+            r += "%d" % self.maxdist
         if self.boost != 1.0:
             r += "^%f" % self.boost
         return r
         termset.add((self.fieldname, self.text))
 
     def _words(self, ixreader):
-        text = self.text
-        maxdist = self.maxdist
-        for term in ixreader.expand_prefix(self.fieldname,
-                                           text[:self.prefixlength]):
-            if text == term:
-                yield term
-            elif distance(text, term) <= maxdist:
-                yield term
+        return ixreader.terms_within(self.fieldname, self.text, self.maxdist,
+                                     prefix=self.prefixlength)
 
 
 class RangeMixin(object):
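
The __unicode__ change above renders the fuzzy operator as a suffix,
matching the usual query syntax, and appends the edit distance when it is
greater than one. For example, assuming FuzzyTerm's existing constructor
signature:

    q = FuzzyTerm("text", u"specail", maxdist=2, boost=1.5)
    print unicode(q)  # prints: specail~2^1.500000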

File src/whoosh/reading.py

 from bisect import bisect_right
 from heapq import heapify, heapreplace, heappop, nlargest
 
+from whoosh.support.levenshtein import distance
 from whoosh.util import ClosableMixin
 from whoosh.matching import MultiMatcher
 
                 yield (vec.id(), decoder(vec.value()))
                 vec.next()
 
+    def has_word_graph(self, fieldname):
+        """Returns True if the given field has a "word graph" associated with
+        it, allowing suggestions for correcting mis-typed words and fast fuzzy
+        term searching.
+        """
+        
+        return False
+    
+    def terms_within(self, fieldname, text, maxdist, prefix=0):
+        """Returns a generator of words in the given field within ``maxdist``
+        Damerau-Levenshtein edit distance of the given text.
+        
+        :param prefix: require suggestions to share a prefix of this length
+            with the given word. This is often justifiable since most
+            misspellings do not involve the first letter of the word.
+            Using a prefix dramatically decreases the time it takes to generate
+            the list of words.
+        """
+        
+        # The default implementation uses brute force to scan the entire word
+        # list and calculate the edit distance for each word. Backends that
+        # store a word graph can override this with something more elegant.
+        
+        for word in self.expand_prefix(fieldname, text[:prefix]):
+            if word == text:
+                yield text
+            elif distance(word, text, limit=maxdist) <= maxdist:
+                yield word
+    
     def most_frequent_terms(self, fieldname, number=5, prefix=''):
         """Returns the top 'number' most frequent terms in the given field as a
         list of (frequency, text) tuples.
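
Since the fallback only relies on expand_prefix(), any reader gets a
brute-force terms_within() for free. A toy sketch (WordListReader is
hypothetical):

    class WordListReader(IndexReader):
        def __init__(self, words):
            self.words = sorted(words)
        
        def expand_prefix(self, fieldname, prefix):
            return (w for w in self.words if w.startswith(prefix))
    
    r = WordListReader([u"special", u"spatial", u"read"])
    print list(r.terms_within(None, u"speciel", 2))  # prints: [u'special']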

File src/whoosh/searching.py

 
 from whoosh import classify, highlight, query, scoring
 from whoosh.reading import TermNotFound
+from whoosh.spelling import suggest
 from whoosh.support.bitvector import BitSet, BitVector
 from whoosh.util import now, lru_cache
 
         # Copy attributes/methods from wrapped reader
         for name in ("stored_fields", "all_stored_fields", "vector", "vector_as",
                      "lexicon", "frequency", "doc_frequency", 
-                     "field_length", "doc_field_length", "max_field_length",
-                     ):
+                     "field_length", "doc_field_length", "max_field_length"):
             setattr(self, name, getattr(self.ixreader, name))
 
     def __enter__(self):
             for docnum in q.docs(self):
                 yield docnum
 
+    def suggest(self, fieldname, text, limit=5, maxdist=2, prefix=0):
+        """Returns a sorted list of suggested corrections for the given
+        mis-typed word based on the contents of the given field.
+        
+        See :func:`whoosh.spelling.suggest` for more information.
+        """
+        
+        return suggest(self.reader(), fieldname, text, limit=limit,
+                       maxdist=maxdist, prefix=prefix)
+
     def key_terms(self, docnums, fieldname, numterms=5,
                   model=classify.Bo1Model, normalize=True):
         """Returns the 'numterms' most important terms from the documents

File src/whoosh/spelling.py

 # those of the authors and should not be interpreted as representing official
 # policies, either expressed or implied, of Matt Chaput.
 
-"""This module contains functions/classes using a Whoosh index as a backend for
-a spell-checking engine.
+"""This module contains helper functions for correcting typos in user queries.
 """
 
 from collections import defaultdict
+from heapq import nsmallest
 
 from whoosh import analysis, fields, query, scoring
-from whoosh.support.levenshtein import relative, distance
+from whoosh.support.levenshtein import distance
 
 
+def suggest(reader, fieldname, text, limit=5, maxdist=2, prefix=0):
+    """Returns a sorted list of suggested corrections for the given
+    mis-typed word based on the contents of the given field.
+    
+    This method ranks suggestions first by lowest Damerau-Levenshtein edit
+    distance, then by highest term frequency, so more common words will be
+    suggested first.
+    
+    >>> r = ix.reader()
+    >>> suggest(r, "text", "specail")
+    [u'special']
+    
+    :param reader: an object which implements the ``terms_within`` and
+        ``frequency`` methods.
+    :param fieldname: the field to use for words. This may be None if the
+        "reader" does not support fields.
+    :param limit: only return up to this many suggestions. If there are not
+        enough terms in the field within ``maxdist`` of the given word, the
+        returned list will be shorter than this number.
+    :param maxdist: the largest edit distance from the given word to look
+        at. Numbers higher than 2 are not very effective or efficient.
+    :param prefix: require suggestions to share a prefix of this length
+        with the given word. This is often justifiable since most misspellings
+        do not involve the first letter of the word. Using a prefix
+        dramatically decreases the time it takes to generate the list of words.
+    """
+    
+    seen = set()
+    candidates = []
+    for k in xrange(1, maxdist + 1):
+        for sug in reader.terms_within(fieldname, text, k, prefix=prefix):
+            if sug in seen:
+                continue
+            seen.add(sug)
+            candidates.append((k, 0 - reader.frequency(fieldname, sug), sug))
+        
+        # Each pass widens the allowed edit distance, so once the closer
+        # passes have produced enough candidates we can stop looking.
+        if len(candidates) >= limit:
+            break
+    
+    return [sug for _, _, sug in nsmallest(limit, candidates)]
+
+
+class Corrector(object):
+    """This class allows you to generate suggested corrections for mis-typed
+    words based on a word list. Note that if you want to generate suggestions
+    based on the content of a field in an index, you should turn spelling on
+    for the field and use :meth:`whoosh.searching.Searcher.suggest` instead of
+    this object.
+    
+    """
+    
+    def __init__(self):
+        pass
+
+
+# Old, obsolete spell checker
+
 class SpellChecker(object):
-    """Implements a spell-checking engine using a search index for the backend
+    """This feature is obsolete. Instead use either a field with spelling
+    turned on or a :class:`Corrector`.
+    
+    Implements a spell-checking engine using a search index for the backend
     storage and lookup. This class is based on the Lucene contributed spell-
     checker code.
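
A migration sketch: the old class built and maintained a separate spelling
index, while the new approach reads suggestions straight from the main
index (assuming the old add_field/suggest API and an index ix whose "text"
field has spelling=True):

    # Old, obsolete flow
    speller = SpellChecker(ix.storage)
    speller.add_field(ix, "text")
    old_sugs = speller.suggest(u"specail")
    
    # New flow
    new_sugs = suggest(ix.reader(), "text", u"specail")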
     

File src/whoosh/support/dawg.py

         return self._edges
 
 
-class DiskNode(object):
-    caching = True
-    
-    def __init__(self, f, offset, usebytes=True):
-        self.f = f
-        self.offset = offset
-        self._edges = {}
-        
-        f.seek(offset)
-        flags = f.read_byte()
-        
-        lentype = flags & 3
-        if lentype != 0:
-            if lentype == 1:
-                count = flags >> 4
-            elif lentype == 2:
-                count = f.read_byte()
-            else:
-                count = f.read_ushort()
-            
-            for _ in xrange(count):
-                if usebytes:
-                    cnum = f.read_byte()
-                else:
-                    cnum = f.read_ushort()
-                char = unichr(cnum)
-                
-                self._edges[char] = f.read_uint()
-        
-        self.final = flags & 4
-    
-    @classmethod
-    def open(cls, dbfile):
-        dbfile.seek(0)
-        usebytes = bool(dbfile.read_int())
-        ptr = dbfile.read_uint()
-        return cls(dbfile, ptr, usebytes=usebytes)
-    
-    def __repr__(self):
-        return "<%s:%s %s>" % (self.offset, "".join(self.ptrs.keys()), self.final)
-    
-    def __contains__(self, key):
-        return key in self._edges
-    
-    def edge(self, key):
-        v = self._edges[key]
-        if not isinstance(v, DiskNode):
-            # Convert pointer to disk node
-            v = DiskNode(self.f, v)
-            #if self.caching:
-            self._edges[key] = v
-        return v
-    
-    def all_edges(self):
-        e = self.edge
-        return dict((key, e(key)) for key in self._edges.iterkeys())
-    
-    def load(self, depth=1):
-        for key in self._keys:
-            node = self.edge(key)
-            if depth:
-                node.load(depth - 1)
-
-
 class DawgWriter(object):
     def __init__(self, dbfile):
         self.dbfile = dbfile
         self.unchecked = []
         # List of unique nodes that have been checked for duplication.
         self.minimized = {}
+        
+        # Maps fieldnames to node starts
+        self.fields = {}
+        self._reset()
+        
+        dbfile.write_int(0)  # File flags
+        dbfile.write_uint(0)  # Pointer to field index
+    
+    def _reset(self):
+        self.fieldname = None
+        # Words are checked for sorted order per-field, so the last word
+        # must be cleared when a new field starts
+        self.lastword = ""
         self.root = DawgNode()
         self.offsets = {}
-        self.usebytes = True
+    
+    def add(self, fieldname, text):
+        if fieldname != self.fieldname:
+            if fieldname in self.fields:
+                raise Exception("I already wrote %r!" % fieldname)
+            if self.fieldname is not None:
+                self._write_field()
+            self.fieldname = fieldname
+        
+        self.insert(text)
     
     def insert(self, word):
         if word < self.lastword:
             node = self.unchecked[-1][2]
 
         for letter in word[prefixlen:]:
-            if ord(letter) > 255: 
-                self.usebytes = False
             nextnode = DawgNode()
             node.put(letter, nextnode)
             self.unchecked.append((node, letter, nextnode))
                 self.minimized[child] = child;
             self.unchecked.pop()
 
-    def lookup(self, word):
-        node = self.root
-        for letter in word:
-            if letter not in node._edges: return False
-            node = node._edges[letter]
-
-        return node.final
-
-    def node_count(self):
-        return len(self.minimized)
-
-    def edge_count(self):
-        count = 0
-        for node in self.minimized:
-            count += len(node._edges)
-        return count
-    
     def close(self):
-        self._minimize(0);
+        if self.fieldname is not None:
+            self._write_field()
+        dbfile = self.dbfile
         
-        dbfile = self.dbfile
-        dbfile.write_int(self.usebytes)  # File flags
-        dbfile.write_uint(0)  # Pointer
-        start = self._write(self.root)
+        self.indexpos = dbfile.tell()
+        dbfile.write_pickle(self.fields)
         dbfile.flush()
         dbfile.seek(_INT_SIZE)
-        dbfile.write_uint(start)
+        dbfile.write_uint(self.indexpos)
         dbfile.close()
     
+    def _write_field(self):
+        self._minimize(0)
+        self.fields[self.fieldname] = self._write(self.root)
+        self._reset()
+        
     def _write(self, node):
         dbfile = self.dbfile
-        keys = sorted(node._edges.keys())
+        keys = node._edges.keys()
         nkeys = len(keys)
         ptrs = []
         for key in keys:
             # Otherwise, write count as an unsigned short
             flags |= 3
         
+        if nkeys:
+            # Fourth lowest bit indicates whether the keys are 1 or 2 bytes
+            singlebytes = all(ord(key) <= 255 for key in keys)
+            flags |= singlebytes << 3
+        
         # Third lowest bit indicates whether this node ends a word
         flags |= node.final << 2
+        
         dbfile.write_byte(flags)
-        
         if nkeys:
             # If number of keys is < 16, it's stashed in the flags byte
             if nkeys >= 16 and nkeys <= 255:
                 dbfile.write_ushort(nkeys)
             
             for i in xrange(nkeys):
-                if self.usebytes:
-                    dbfile.write(keys[i])
+                charnum = ord(keys[i])
+                if singlebytes: 
+                    dbfile.write_byte(charnum)
                 else:
-                    dbfile.write_ushort(ord(keys[i]))
+                    dbfile.write_ushort(charnum)
                 dbfile.write_uint(ptrs[i])
         
         return start
 
 
-def suggest(node, word, rset, k=1, i=0, sofar="", prefix=0):
+class DiskNode(object):
+    caching = True
+    
+    def __init__(self, f, offset):
+        self.f = f
+        self.offset = offset
+        self._edges = {}
+        
+        f.seek(offset)
+        flags = f.read_byte()
+        
+        lentype = flags & 3
+        if lentype != 0:
+            if lentype == 1:
+                count = flags >> 4
+            elif lentype == 2:
+                count = f.read_byte()
+            else:
+                count = f.read_ushort()
+            
+            singlebytes = flags & 8
+            for _ in xrange(count):
+                if singlebytes:
+                    char = unichr(f.read_byte())
+                else:
+                    char = unichr(f.read_ushort())
+                
+                self._edges[char] = f.read_uint()
+        
+        self.final = flags & 4
+    
+    def __repr__(self):
+        return "<%s:%s %s>" % (self.offset, "".join(self._edges.keys()), bool(self.final))
+    
+    def __contains__(self, key):
+        return key in self._edges
+    
+    def edge(self, key):
+        v = self._edges[key]
+        if not isinstance(v, DiskNode):
+            # Convert pointer to disk node
+            v = DiskNode(self.f, v)
+            # Cache the loaded node (caching is currently always on)
+            self._edges[key] = v
+        return v
+    
+    def all_edges(self):
+        e = self.edge
+        return dict((key, e(key)) for key in self._edges.iterkeys())
+    
+    def load(self, depth=1):
+        for key in self._edges:
+            node = self.edge(key)
+            if depth:
+                node.load(depth - 1)
+
+
+class DawgReader(object):
+    def __init__(self, dbfile):
+        self.dbfile = dbfile
+        
+        dbfile.seek(0)
+        self.fileflags = dbfile.read_int()
+        self.indexpos = dbfile.read_uint()
+        dbfile.seek(self.indexpos)
+        self.fields = dbfile.read_pickle()
+        
+    def field_root(self, fieldname):
+        v = self.fields[fieldname]
+        if not isinstance(v, DiskNode):
+            v = DiskNode(self.dbfile, v)
+            self.fields[fieldname] = v
+        return v
+    
+    def within(self, fieldname, text, k=1, prefix=0, seen=None):
+        if seen is None:
+            seen = set()
+        
+        node = self.field_root(fieldname)
+        sofar = ""
+        if prefix:
+            node = skip_prefix(node, text, prefix)
+            if node is None:
+                return
+            sofar, text = text[:prefix], text[prefix:]
+        
+        for sug in within(node, text, k, sofar=sofar):
+            if sug in seen:
+                continue
+            yield sug
+            seen.add(sug)
+            
+
+def within(node, word, k=1, i=0, sofar=""):
     assert k >= 0
-    if prefix:
-        node = advance_through(node, word[:prefix])
-        if node is None:
-            return
-        sofar, word = word[:prefix], word[prefix:]
     
     if i == len(word) and node.final:
-        rset.add(sofar)
+        yield sofar
     
     # Match
     if i < len(word) and word[i] in node:
-        suggest(node.edge(word[i]), word, rset, k, i + 1, sofar + word[i])
+        for w in within(node.edge(word[i]), word, k, i + 1, sofar + word[i]):
+            yield w
     
     if k > 0:
         dk = k - 1
         ii = i + 1
         edges = node.all_edges()
         # Insertions
-        for label in edges:
-            suggest(edges[label], word, rset, dk, i, sofar + label)
+        for key in edges:
+            for w in within(edges[key], word, dk, i, sofar + key):
+                yield w
         
         if i < len(word):
             char = word[i]
             if i < len(word) - 1 and word[ii] in edges:
                 second = edges[word[i+1]]
                 if char in second:
-                    suggest(second.edge(char), word, rset, dk, i + 2,
-                            sofar + word[ii] + char)
+                    for w in within(second.edge(char), word, dk, i + 2,
+                                     sofar + word[ii] + char):
+                        yield w
             
             # Deletion
-            suggest(node, word, rset, dk, ii, sofar)
+            for w in within(node, word, dk, ii, sofar):
+                yield w
             
             # Replacements
-            for label in edges:
-                if label != char:
-                    suggest(edges[label], word, rset, dk, ii, sofar + label)
+            for key in edges:
+                if key != char:
+                    for w in within(edges[key], word, dk, ii, sofar + key):
+                        yield w
 
 
-def advance_through(node, prefix):
-    for key in prefix:
+def skip_prefix(node, text, prefix):
+    for key in text[:prefix]:
         if key in node:
             node = node.edge(key)
         else:
             return None
     return node
-    
+
 
 def find_nearest(node, prefix):
     sofar = []
 def run_out(node, sofar):
     sofar = []
     while not node.final:
-        first = node.keys()[0]
+        first = min(node.keys())
         sofar.append(first)
         node = node.edge(first)
     return sofar
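
To illustrate the traversal, here is a toy run of the module-level within()
generator on a hand-built graph. ToyNode is a hypothetical stand-in that
provides the small interface within() relies on (final, __contains__,
edge and all_edges):

    class ToyNode(object):
        def __init__(self, final=False):
            self.final = final
            self._edges = {}
        
        def __contains__(self, key):
            return key in self._edges
        
        def edge(self, key):
            return self._edges[key]
        
        def all_edges(self):
            return self._edges
    
    def add_word(root, word):
        node = root
        for char in word:
            if char not in node:
                node._edges[char] = ToyNode()
            node = node.edge(char)
        node.final = True
    
    root = ToyNode()
    for w in (u"special", u"specials", u"bicycle"):
        add_word(root, w)
    
    # within() can reach the same word along different edit paths, which
    # is why DawgReader.within() deduplicates with a set
    print sorted(set(within(root, u"specail", k=2)))
    # prints: [u'special', u'specials']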

File src/whoosh/support/levenshtein.py

             addcost = thisrow[y - 1] + 1
             subcost = oneago[y - 1] + (seq1[x] != seq2[y])
             thisrow[y] = min(delcost, addcost, subcost)
-        if limit and thisrow[x] > limit:
-            return thisrow[x]
+        
+        if limit and x > limit and min(thisrow) > limit:
+            return limit + 1
+        
     return thisrow[len(seq2) - 1]
 
 
             if (x > 0 and y > 0 and seq1[x] == seq2[y - 1]
                 and seq1[x-1] == seq2[y] and seq1[x] != seq2[y]):
                 thisrow[y] = min(thisrow[y], twoago[y - 2] + 1)
-        if limit and thisrow[x] > limit:
-            return thisrow[x]
+        
+        if limit and x > limit and min(thisrow) > limit:
+            return limit + 1
+        
     return thisrow[len(seq2) - 1]
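
With the corrected early exit, these functions return limit + 1 as soon as
every entry in the current row exceeds the limit, so the return value is
only reliable as a "within the limit or not" test. For example, using the
distance() function imported elsewhere in this commit:

    d = distance(u"special", u"spectacular", limit=2)
    # d may be limit + 1 rather than the true edit distance; callers such
    # as the default terms_within() should only test d <= limit
    assert d > 2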