Commits

Matt Chaput  committed debdee2

Changed on-disk representation of DAWG to radix tree.

  • Participants
  • Parent commits 713315d
  • Branches dawg

Comments (0)

Files changed (1)

File src/whoosh/support/dawg.py

 # those of the authors and should not be interpreted as representing official
 # policies, either expressed or implied, of Matt Chaput.
 
+from array import array
 
 from whoosh.system import _INT_SIZE
+from whoosh.util import utf8encode, utf8decode
 
 
 class BaseNode(object):
     
     * ``final`` is a property which is True if this node represents the end of
       a word.
+      
     * ``__contains__(label)`` returns True if the node has an edge with the
-      given
-      label.
+      given label.
+      
     * ``__iter__()`` returns an iterator of the labels for the node's outgoing
       edges.
+      
+    * ``__len__()`` returns the number of outgoing edges.
+    
     * ``edge(label)`` returns the Node connected to the edge with the given
       label.
+      
     * ``all_edges()`` returns a dictionary of the node's outgoing edges, where
       the keys are the edge labels and the values are the connected nodes.
     """
     def __iter__(self):
         raise NotImplementedError
     
+    def __len__(self):
+        raise NotImplementedError
+    
     def edge(self, key):
         raise NotImplementedError
     
     def all_edges(self):
         e = self.edge
         return dict((key, e(key)) for key in self)
+    
+    def edge_count(self):
+        return len(self) + sum(self.edge(key).edge_count() for key in self)
 
 
-class DawgNode(object):
+class BuildNode(object):
     def __init__(self):
         self.final = False
         self._edges = {}
     def __iter__(self):
         return iter(self._edges)
     
+    def __len__(self):
+        return len(self._edges)
+    
     def put(self, key, node):
         self._hash = None  # Invalidate the cached hash value
         self._edges[key] = node
     
     def all_edges(self):
         return self._dict
-
+    
 
 class DawgWriter(object):
-    def __init__(self, dbfile):
+    def __init__(self, dbfile, reduced=True):
         self.dbfile = dbfile
+        self.reduced = reduced
+        
         self.lastword = ""
         # List of nodes that have not been checked for duplication.
         self.unchecked = []
     
     def _reset(self):
         self.fieldname = None
-        self.root = DawgNode()
+        self.root = BuildNode()
         self.offsets = {}
     
     def add(self, fieldname, text):
     
     def insert(self, word):
         if word < self.lastword:
-            raise Exception("Error: Words must be inserted in alphabetical " +
-                "order.")
+            raise Exception("Out of order %r..%r." % (self.lastword, word))
 
         # find common prefix between word and previous word
         prefixlen = 0
 
         # Add the suffix, starting from the correct node mid-way through the
         # graph
-        if len(self.unchecked) == 0:
+        if not self.unchecked:
             node = self.root
         else:
             node = self.unchecked[-1][2]
 
         for letter in word[prefixlen:]:
-            nextnode = DawgNode()
+            nextnode = BuildNode()
             node.put(letter, nextnode)
             self.unchecked.append((node, letter, nextnode))
             node = nextnode
         dbfile.close()
     
     def _write_field(self):
-        self._minimize(0);
-        self.fields[self.fieldname] = self._write(self.root)
+        self._minimize(0)
+        root = self.root
+        if self.reduced:
+            reduce(root)
+        self.fields[self.fieldname] = self._write_node(root)
         self._reset()
         
-    def _write(self, node):
+    def _write_node(self, node):
         dbfile = self.dbfile
         keys = node._edges.keys()
-        nkeys = len(keys)
-        ptrs = []
+        ptrs = array("I")
         for key in keys:
             sn = node._edges[key]
             if id(sn) in self.offsets:
                 ptrs.append(self.offsets[id(sn)])
             else:
-                ptr = self._write(sn)
+                ptr = self._write_node(sn)
                 self.offsets[id(sn)] = ptr
                 ptrs.append(ptr)
         
         start = dbfile.tell()
         
-        # The low two bits of the flags byte indicate how the number of edges
-        # is written
-        flags = 0
-        if nkeys == 0:
-            # No outbound edges, no edge count will be written
-            pass
-        elif nkeys < 16:
-            # Count is < 16, store it in the upper 4 bits of the flags byte
-            flags |= 1 | (nkeys << 4)
-        elif nkeys < 255:
-            # Count is < 255, write as a byte
-            flags |= 2
-        else:
-            # Otherwise, write count as an unsigned short
-            flags |= 3
+        # The low bit indicates whether this node represents the end of a word
+        flags = int(node.final)
+        # The second lowest bit = whether this node has children
+        flags |= bool(keys) << 1
+        # The third lowest bit = whether all keys are single chars
+        singles = all(len(k) == 1 for k in keys)
+        flags |= singles << 2
+        # The fourth lowest bit = whether all keys are one byte
+        if singles:
+            bytes = all(ord(key) <= 255 for key in keys)
+            flags |= bytes << 3
+        dbfile.write_byte(flags)
         
-        if nkeys:
-            # Fourth lowest bit indicates whether the keys are 1 or 2 bytes
-            singlebytes = all(ord(key) <= 255 for key in keys)
-            flags |= singlebytes << 3
-        
-        # Third lowest bit indicates whether this node ends a word
-        flags |= node.final << 2
-        
-        dbfile.write_byte(flags)
-        if nkeys:
-            # If number of keys is < 16, it's stashed in the flags byte
-            if nkeys >= 16 and nkeys <= 255:
-                dbfile.write_byte(nkeys)
-            elif nkeys > 255:
-                dbfile.write_ushort(nkeys)
-            
-            for i in xrange(nkeys):
-                charnum = ord(keys[i])
-                if singlebytes: 
-                    dbfile.write_byte(charnum)
-                else:
-                    dbfile.write_ushort(charnum)
-                dbfile.write_uint(ptrs[i])
+        if keys:
+            dbfile.write_varint(len(keys))
+            dbfile.write_array(ptrs)
+            if singles:
+                for key in keys:
+                    o = ord(key)
+                    if bytes:
+                        dbfile.write_byte(o)
+                    else:
+                        dbfile.write_ushort(o)
+            else:
+                for key in keys:
+                    dbfile.write_string(utf8encode(key)[0])
         
         return start
 
 
 class DiskNode(BaseNode):
-    caching = True
+    def __init__(self, dr, offset):
+        self.dr = dr
+        self.id = offset
+        
+        dbfile = dr.dbfile
+        dbfile.seek(offset)
+        flags = dbfile.read_byte()
+        self.final = bool(flags & 1)
+        self._edges = {}
+        if flags & 2:
+            singles = flags & 4
+            bytes = flags & 8
+            
+            nkeys = dbfile.read_varint()
+            
+            ptrs = dbfile.read_array("I", nkeys)
+            for i in xrange(nkeys):
+                ptr = ptrs[i]
+                if singles:
+                    if bytes:
+                        charnum = dbfile.read_byte()
+                    else:
+                        charnum = dbfile.read_ushort()
+                    self._edges[unichr(charnum)] = ptr
+                else:
+                    key = utf8decode(dbfile.read_string())[0]
+                    if len(key) > 1:
+                        self._edges[key[0]] = PatNode(dr, key[1:], ptr)
+                    else:
+                        self._edges[key] = ptr
+            
+    def __repr__(self):
+        return "<%s:%s %s>" % (self.id, "".join(self._edges.keys()), self.final)
     
-    def __init__(self, f, offset):
-        self.f = f
-        self.offset = offset
-        self._edges = {}
-        
-        f.seek(offset)
-        flags = f.read_byte()
-        
-        lentype = flags & 3
-        if lentype != 0:
-            if lentype == 1:
-                count = flags >> 4
-            elif lentype == 2:
-                count = f.read_byte()
-            else:
-                count = f.read_ushort()
-            
-            singlebytes = flags & 8
-            for _ in xrange(count):
-                if singlebytes:
-                    char = unichr(f.read_byte())
-                else:
-                    char = unichr(f.read_ushort())
-                
-                self._edges[char] = f.read_uint()
-        
-        self.final = flags & 4
+    def __contains__(self, key):
+        return key in self._edges
     
-    def __repr__(self):
-        return "<%s:%s %s>" % (self.offset, "".join(self._edges.keys()), bool(self.final))
+    def __iter__(self):
+        return iter(self._edges)
+    
+    def __len__(self):
+        return len(self._edges)
     
     def edge(self, key):
         v = self._edges[key]
-        if not isinstance(v, DiskNode):
+        if not isinstance(v, BaseNode):
             # Convert pointer to disk node
-            v = DiskNode(self.f, v)
+            v = DiskNode(self.dr, v)
             #if self.caching:
             self._edges[key] = v
         return v
     
-    def load(self, depth=1):
-        for key in self._keys:
-            node = self.edge(key)
-            if depth:
-                node.load(depth - 1)
+
+class PatNode(BaseNode):
+    final = False
+    
+    def __init__(self, dr, label, nextptr, i=0):
+        self.dr = dr
+        self.label = label
+        self.nextptr = nextptr
+        self.i = i
+    
+    def __repr__(self):
+        return "<%r(%d) %s>" % (self.label, self.i, self.final)
+    
+    def __contains__(self, key):
+        if self.i < len(self.label) and key == self.label[self.i]:
+            return True
+        else:
+            return False
+    
+    def __iter__(self):
+        if self.i < len(self.label):
+            return iter(self.label[self.i])
+        else:
+            return []
+    
+    def __len__(self):
+        if self.i < len(self.label):
+            return 1
+        else:
+            return 0
+    
+    def edge(self, key):
+        label = self.label
+        i = self.i
+        if i < len(label) and key == label[i]:
+            i += 1
+            if i < len(self.label):
+                return PatNode(self.dr, label, self.nextptr, i)
+            else:
+                return DiskNode(self.dr, self.nextptr)
+        else:
+            raise KeyError(key)
+        
+    def edge_count(self):
+        return DiskNode(self.dr, self.nextptr).edge_count()
 
 
 class ComboNode(BaseNode):
     def __iter__(self):
         return iter(set(self.a) | set(self.b))
     
+    def __len__(self):
+        return len(set(self.a) | set(self.b))
+    
     @property
     def final(self):
         return self.a.final or self.b.final
-
+    
 
 class PriorityNode(ComboNode):
     def edge(self, key):
             return self.a.edge(key)
         else:
             return self.b.edge(key)
-
+        
 
 class MixedNode(ComboNode):
     def edge(self, key):
             return a.edge(key)
         else:
             return b.edge(key)
-    
+        
 
 class DawgReader(object):
     def __init__(self, dbfile):
         
     def field_root(self, fieldname):
         v = self.fields[fieldname]
-        if not isinstance(v, DiskNode):
-            v = DiskNode(self.dbfile, v)
+        if not isinstance(v, BaseNode):
+            v = DiskNode(self, v)
             self.fields[fieldname] = v
         return v
 
 
 # Functions
 
+def reduce(node):
+    edges = node._edges
+    if edges:
+        for key, sn in edges.items():
+            reduce(sn)
+            if len(sn) == 1:
+                skey, ssn = sn._edges.items()[0]
+                if sn.final == ssn.final or (ssn.final and len(ssn) == 0):
+                    del edges[key]
+                    edges[key + skey] = ssn
+                
+
+def edge_count(node):
+    c = len(node)
+    return c + sum(edge_count(node.edge(key)) for key in node)
+
+
+def dump_dawg(node, tab=0):
+    print "  " * tab, id(node), node.final
+    for key in node:
+        print "  " * tab, key, ":"
+        dump_dawg(node.edge(key), tab + 1)
+
+
 def within(node, text, k=1, prefix=0, seen=None):
     if seen is None:
         seen = set()