Commits

Matt Chaput committed 27170cf

Cleanups and docs.

Files changed (5)

src/whoosh/filedb/filereading.py

                                       LengthReader, TermVectorReader)
 from whoosh.matching import FilterMatcher, ListMatcher
 from whoosh.reading import IndexReader, TermNotFound
-from whoosh.support.dawg import DawgReader
+from whoosh.support.dawg import DiskNode
 from whoosh.util import protected
 
 SAVE_BY_DEFAULT = True
             fname = segment.dawg_filename
             if self.storage.file_exists(fname):
                 dawgfile = self.storage.open_file(fname, mapped=False)
-                self.dawg = DawgReader(dawgfile, expand=False).root
+                self.dawg = DiskNode.load(dawgfile, expand=False)
         
         self.dc = segment.doc_count_all()
         assert self.dc == self.storedfields.length
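
The reader no longer constructs a DawgReader and takes its root attribute; DiskNode.load() returns the root node directly. A minimal sketch of the loading side, assuming a graph file previously written by DawgBuilder.write() (the "terms.dawg" path is only illustrative):

    from whoosh.filedb.structfile import StructFile
    from whoosh.support.dawg import DiskNode

    # "terms.dawg" stands in for a graph file written earlier by DawgBuilder.write()
    dbfile = StructFile(open("terms.dawg", "rb"))

    # expand=False keeps multi-character edge labels as single edges instead of
    # expanding them into PatNode chains; the reader uses it because the root's
    # edges are whole field names rather than single characters.
    root = DiskNode.load(dbfile, expand=False)
    print(root.keys())  # field names at the root of a segment's word graph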

src/whoosh/filedb/filewriting.py

                                       TermVectorWriter)
 from whoosh.filedb.pools import TempfilePool
 from whoosh.store import LockError
-from whoosh.support.dawg import DawgWriter
+from whoosh.support.dawg import DawgBuilder
 from whoosh.support.filelock import try_for
 from whoosh.util import fib
 from whoosh.writing import IndexWriter, IndexingError
         dawg = None
         if any(field.spelling for field in self.schema):
             df = self.storage.create_file(segment.dawg_filename)
-            dawg = DawgWriter(df, reduce_root=False)
+            dawg = DawgBuilder(df, reduce_root=False)
         
         # Terms index
         tf = self.storage.create_file(segment.termsindex_filename)
         filename = segment.dawg_filename
         r = SegmentReader(storage, schema, segment)
         f = storage.create_file(filename)
-        dw = DawgWriter(reduce_root=False)
+        dawg = DawgBuilder(reduce_root=False)
         for fieldname in fieldnames:
             ft = (fieldname, )
             for word in r.lexicon(fieldname):
-                dw.insert(ft + tuple(word))
-        dw.write(f)
+                dawg.insert(ft + tuple(word))
+        dawg.write(f)
     
     for fieldname in fieldnames:
         schema[fieldname].spelling = True
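
In both functions the builder is created with reduce_root=False and each word is inserted as a tuple whose first element is the field name, so the root's outgoing edges are field names rather than characters and stay out of the Patricia-style reduction. A minimal in-memory sketch of the same pattern (the field names and words here are made up):

    from whoosh.support.dawg import DawgBuilder, flatten

    builder = DawgBuilder(reduce_root=False)
    # Each entry is (fieldname,) + the characters of the word; entries must be
    # inserted in sorted order.
    for fieldname, word in [("content", u"render"), ("content", u"rendering"),
                            ("title", u"alfa")]:
        builder.insert((fieldname,) + tuple(word))
    builder.finish()

    # The root's edges are the field names; each field edge leads to that
    # field's word graph.
    print(sorted(builder.root))                           # ["content", "title"]
    print(list(flatten(builder.root.edge("content"))))    # [u"render", u"rendering"]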

src/whoosh/spelling.py

     
     @classmethod
     def from_word_list(cls, wordlist, ranking=None, fieldname=""):
-        dw = dawg.DawgWriter()
+        dw = dawg.DawgBuilder()
         for word in wordlist:
             dw.insert(word)
         return cls(dw.root, ranking=ranking)
     
     @classmethod
     def from_graph_file(cls, dbfile, ranking=None, fieldname=""):
-        dr = dawg.DawgReader(dbfile)
-        return cls(dr.root, ranking=ranking)
+        dr = dawg.DiskNode.load(dbfile)
+        return cls(dr, ranking=ranking)
     
 
     
     from whoosh.filedb.structfile import StructFile
     
-    dw = dawg.DawgWriter()
+    dw = dawg.DawgBuilder()
     for word in wordlist:
         dw.insert(word)
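
For reference, a sketch of how the classmethod constructors above might be used. The class that owns from_word_list() and from_graph_file() is not visible in this hunk, so GraphCorrector is only a stand-in name for it:

    from whoosh import spelling

    # The word list must be sorted, since it is fed to DawgBuilder.insert()
    wordlist = sorted([u"alfa", u"bravo", u"charlie"])
    corrector = spelling.GraphCorrector.from_word_list(wordlist)
    print(corrector.suggest(u"brovo", maxdist=1))  # expected: [u"bravo"]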
     

src/whoosh/support/dawg.py

 # those of the authors and should not be interpreted as representing official
 # policies, either expressed or implied, of Matt Chaput.
 
+"""
+This module contains classes and functions for working with Directed Acyclic
+Word Graphs (DAWGs). This structure is used to efficiently store a list of
+words.
+
+This code should be considered an implementation detail and may change in
+future releases.
+
+TODO: try to find a way to traverse the term index efficiently to do within()
+instead of storing a DAWG separately.
+"""
+
+import re
 from array import array
 
 from whoosh.system import _INT_SIZE
       given label.
       
     * ``__iter__()`` returns an iterator of the labels for the node's outgoing
-      edges.
+      edges. ``keys()`` is available as a convenient shortcut to get a list.
       
     * ``__len__()`` returns the number of outgoing edges.
     
         raise NotImplementedError
     
     def keys(self):
+        """Returns a list of the outgoing edge labels.
+        """
+        
         return list(self)
     
     def edge(self, key, expand=True):
+        """Returns the node connected to the outgoing edge with the given label.
+        """
+        
         raise NotImplementedError
     
     def all_edges(self):
+        """Returns a dictionary mapping outgoing edge labels to nodes.
+        """
+        
         e = self.edge
         return dict((key, e(key)) for key in self)
     
     def edge_count(self):
+        """Returns the recursive count of edges in this node and the tree under
+        it.
+        """
+        
         return len(self) + sum(self.edge(key).edge_count() for key in self)
 
 
 
 
 class BuildNode(object):
+    """Node type used by DawgBuilder when constructing a graph from scratch.
+    """
+    
     def __init__(self):
         self.final = False
         self._edges = {}
         self._hash = None
 
     def __repr__(self):
-        return "<%s:%s %s>" % (self.id, "".join(self._edges.keys()), self.final)
+        return "<%s:%s %s>" % (self.__class__.__name__, ",".join(self._edges.keys()), self.final)
 
     def __hash__(self):
         if self._hash is not None:
         return self._edges[key]
     
     def all_edges(self):
-        return self._dict
+        return self._edges
     
 
-class DawgWriter(object):
+class DawgBuilder(object):
+    """Class for building a graph from scratch.
+    
+    >>> db = DawgBuilder()
+    >>> db.insert(u"alfa")
+    >>> db.insert(u"bravo")
+    >>> db.write(dbfile)
+    
+    This class does not have the cleanest API, because it was cobbled together
+    to support the spelling correction system.
+    """
+    
     def __init__(self, dbfile=None, reduced=True, reduce_root=True):
+        """
+        :param dbfile: an optional StructFile. If you pass this argument to the
+            initializer, you don't have to pass a file to the ``write()``
+            method after you construct the graph.
+        :param reduced: when the graph is finished, branches of single-edged
+            nodes will be collapsed into single nodes to form a Patricia tree.
+        :param reduce_root: when ``reduced`` is True and this argument is also
+            True, reduction will include the root node. If the root node's
+            edges are special (as in an index segment's term DAWG, where they
+            are the field names), you can turn this off to keep the root edges
+            "safe" from reduction.
+        """
+        
         self.dbfile = dbfile
         self.reduced = reduced
         self.reduce_root = reduce_root
         self.offsets = {}
     
     def insert(self, word):
+        """Add the given "word" (a string or list of strings) to the graph.
+        Words must be inserted in sorted order.
+        """
+        
         if word < self.lastword:
             raise Exception("Out of order %r..%r." % (self.lastword, word))
 
                 self.minimized[child] = child;
             self.unchecked.pop()
 
-    def write(self, dbfile=None):
-        dbfile = self.dbfile or dbfile
-        dbfile.write("GR01")  # Magic number
-        dbfile.write_int(0)  # File flags
-        dbfile.write_uint(0)  # Pointer to root node
+    def finish(self):
+        """Minimize the graph by merging duplicate nodes and, if ``reduced``
+        is True, collapse branches of single-edged nodes. You can call this
+        explicitly if you are building a graph to use in memory; otherwise it
+        is called automatically by the ``write()`` method.
+        """
         
         self._minimize(0)
         root = self.root
                 for key in root:
                     v = root.edge(key)
                     reduce(v)
-        offset = self._write_node(dbfile, root)
+
+    def write(self, dbfile=None):
+        """Write the graph to the given StructFile. If you passed a file to
+        the initializer, you don't have to pass it here.
+        """
+        
+        dbfile = self.dbfile or dbfile
+        dbfile.write("GR01")  # Magic number
+        dbfile.write_int(0)  # File flags
+        dbfile.write_uint(0)  # Pointer to root node
+        
+        self.finish()
+        offset = self._write_node(dbfile, self.root)
         
         # Seek back and write the pointer to the root node
         dbfile.flush()
 
 
 class DiskNode(BaseNode):
-    def __init__(self, dr, offset, expand=True):
-        self.dr = dr
+    def __init__(self, dbfile, offset, expand=True):
         self.id = offset
+        self.dbfile = dbfile
         
-        dbfile = dr.dbfile
         dbfile.seek(offset)
         flags = dbfile.read_byte()
         self.final = bool(flags & 1)
                 else:
                     key = utf8decode(dbfile.read_string())[0]
                     if len(key) > 1 and expand:
-                        self._edges[key[0]] = PatNode(dr, key[1:], ptr)
+                        self._edges[key[0]] = PatNode(dbfile, key[1:], ptr)
                     else:
                         self._edges[key] = ptr
     
     def __repr__(self):
-        return "<%s %s:%s %s>" % (self.__class__.__name__, self.id, ",".join(self._edges.keys()), self.final)
+        return "<%s %s:%s %s>" % (self.__class__.__name__, self.id,
+                                  ",".join(self._edges.keys()), self.final)
     
     def __contains__(self, key):
         return key in self._edges
         v = self._edges[key]
         if not isinstance(v, BaseNode):
             # Convert pointer to disk node
-            v = DiskNode(self.dr, v, expand=expand)
+            v = DiskNode(self.dbfile, v, expand=expand)
             #if self.caching:
             self._edges[key] = v
         return v
     
+    @classmethod
+    def load(cls, dbfile, expand=True):
+        dbfile.seek(0)
+        magic = dbfile.read(4)
+        assert magic == "GR01"
+        fileflags = dbfile.read_int()
+        return DiskNode(dbfile, dbfile.read_uint(), expand=expand)
+    
 
 class PatNode(BaseNode):
     final = False
     
-    def __init__(self, dr, label, nextptr, i=0):
-        self.dr = dr
+    def __init__(self, dbfile, label, nextptr, i=0):
+        self.dbfile = dbfile
         self.label = label
         self.nextptr = nextptr
         self.i = i
         if i < len(label) and key == label[i]:
             i += 1
             if i < len(self.label):
-                return PatNode(self.dr, label, self.nextptr, i)
+                return PatNode(self.dbfile, label, self.nextptr, i)
             else:
-                return DiskNode(self.dr, self.nextptr)
+                return DiskNode(self.dbfile, self.nextptr)
         else:
             raise KeyError(key)
         
     def edge_count(self):
-        return DiskNode(self.dr, self.nextptr).edge_count()
+        return DiskNode(self.dbfile, self.nextptr).edge_count()
 
 
 class ComboNode(BaseNode):
             return IntersectionNode(a.edge(key), b.edge(key))
 
 
-# Reader for disk-based graph files
-
-class DawgReader(object):
-    def __init__(self, dbfile, expand=True):
-        self.dbfile = dbfile
-        
-        dbfile.seek(0)
-        magic = dbfile.read(4)
-        assert magic == "GR01"
-        self.fileflags = dbfile.read_int()
-        self.root = DiskNode(self, dbfile.read_uint(), expand=expand)
-
-
 # Functions
 
 def reduce(node):
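
Taken together with the new docstrings, a minimal end-to-end sketch of building and querying a graph in memory (the words are arbitrary; reduced=False is used so the traversal below stays one character per edge):

    from whoosh.support.dawg import DawgBuilder, flatten

    builder = DawgBuilder(reduced=False)
    for word in [u"cat", u"cats", u"dog"]:  # insert() requires sorted order
        builder.insert(word)
    builder.finish()  # minimize; write() would otherwise call this for us

    root = builder.root
    print(sorted(root))            # [u"c", u"d"], the root's outgoing edge labels

    # Follow edges character by character and check the final flag.
    node = root
    for char in u"cat":
        node = node.edge(char)
    print(node.final)              # True: "cat" is a word in the graph
    print(sorted(node))            # [u"s"]: "cats" continues from this node
    print(list(flatten(root)))     # [u"cat", u"cats", u"dog"]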

tests/test_spelling.py

 from __future__ import with_statement
 from nose.tools import assert_equal, assert_not_equal
 
+import whoosh.support.dawg as dawg
 from whoosh import fields, spelling
 from whoosh.filedb.filestore import RamStorage
-from whoosh.support.dawg import flatten
 from whoosh.support.testing import TempStorage
 
 
         assert_equal(sp.suggest(u"alfo", maxdist=1), [u"alfa", u"olfo"])
 
 def test_dawg():
-    from whoosh.support.dawg import DawgWriter
+    from whoosh.support.dawg import DawgBuilder
     
     with TempStorage() as st:
         df = st.create_file("test.dawg")
         
-        dw = DawgWriter(reduce_root=False)
+        dw = DawgBuilder(reduce_root=False)
         dw.insert(["test"] + list("special"))
         dw.insert(["test"] + list("specials"))
         dw.write(df)
         
-        assert_equal(list(flatten(dw.root.edge("test"))), ["special", "specials"])
+        assert_equal(list(dawg.flatten(dw.root.edge("test"))), ["special", "specials"])
     
 
 def test_multi():
     
     with ix.reader() as r:
         assert not r.is_atomic()
-        words = list(flatten(r.word_graph("text")))
+        words = list(dawg.flatten(r.word_graph("text")))
         assert_equal(words, sorted(domain))
 
         corr = r.corrector("text")
         
         from whoosh.support.dawg import dump_dawg
         dump_dawg(r.word_graph("text"))
-        words = list(flatten(r.word_graph("text")))
+        words = list(dawg.flatten(r.word_graph("text")))
         assert_equal(words, sorted(domain))
 
         corr = r.corrector("text")