Matt Chaput / whoosh

Commits

Matt Chaput committed 96f6e09

Added "FieldedOrderedHash" implementation.
If no length is passed to init, seek to end of file to get the length.
Added tell() method, added file() method to get underlying file.

  • Parent commit: cfcef02
  • Branch: default
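
A rough sketch of how the additions described in the commit message might be exercised (the storage setup, the file name "test.hsh", and the sample key/value pairs are invented for illustration, and it assumes the RamStorage, HashWriter, and HashReader APIs that already exist elsewhere in whoosh):

    from whoosh.compat import b
    from whoosh.filedb.filestore import RamStorage
    from whoosh.filedb.filetables import HashReader, HashWriter

    st = RamStorage()

    # Write a small hash file; the new tell() method reports the current
    # write position in the underlying file
    hw = HashWriter(st.create_file("test.hsh"))
    hw.add(b("alfa"), b("1"))
    hw.add(b("bravo"), b("2"))
    print(hw.tell())
    hw.close()

    # Open the reader WITHOUT passing a length; per this commit it seeks
    # to the end of the file and works the length out itself
    hr = HashReader(st.open_file("test.hsh"))
    print(hr.file())  # the new file() method exposes the underlying file
    hr.close()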


Files changed (1)

File src/whoosh/filedb/filetables.py

 D. J. Bernstein's CDB format (http://cr.yp.to/cdb.html).
 """
 
-import struct
+import os, struct
 from binascii import crc32
+from bisect import bisect_left
 from hashlib import md5  # @UnresolvedImport
 
 from whoosh.compat import b, bytes_type
         # List to remember the positions of the hash tables
         self.directory = []
 
+    def tell(self):
+        return self.dbfile.tell()
+
     def add(self, key, value):
         """Adds a key/value pair to the file. Note that keys DO NOT need to be
         unique. You can store multiple values under the same key and retrieve
     :class:`HashWriter`.
     """
 
-    def __init__(self, dbfile, length, magic=b("HSH3"), startoffset=0):
+    def __init__(self, dbfile, length=None, magic=b("HSH3"), startoffset=0):
         """
         :param dbfile: a :class:`~whoosh.filedb.structfile.StructFile` object
             to read from.
         self.startoffset = startoffset
         self.is_closed = False
 
+        if length is None:
+            dbfile.seek(0, os.SEEK_END)
+            length = dbfile.tell() - startoffset
+
         dbfile.seek(startoffset)
         # Check format tag
         filemagic = dbfile.read(4)
         dbfile = storage.open_file(name)
         return cls(dbfile, length)
 
+    def file(self):
+        return self.dbfile
+
     def _read_extras(self):
         try:
             self.extras = self.dbfile.read_pickle()
         self.dbfile.close()
         self.is_closed = True
 
-    def _ranges(self, pos=None):
+    def _key_at(self, pos):
+        # Returns the key bytes at the given position
+
+        dbfile = self.dbfile
+        keylen = dbfile.get_uint(pos)
+        return dbfile.get(pos + _lengths.size, keylen)
+
+    def _ranges(self, pos=None, eod=None):
         # Yields a series of (keypos, keylength, datapos, datalength) tuples
         # for the key/value pairs in the file
         dbfile = self.dbfile
         pos = pos or self.startofdata
-        eod = self.endofdata
+        eod = eod or self.endofdata
         lenssize = _lengths.size
         unpacklens = _lengths.unpack
 
         index.to_file(dbfile)
 
 
-class OrderedHashReader(HashReader):
-    def _read_extras(self):
-        dbfile = self.dbfile
-
-        # Read the extras
-        HashReader._read_extras(self)
-
-        # Set up for reading the index array
-        indextype = self.extras["indextype"]
-        self.indexbase = dbfile.tell()
-        self.indexlen = self.extras["indexlen"]
-        self.indexsize = struct.calcsize(indextype)
-        # Set up the function to read values from the index array
-        if indextype == "B":
-            self._get_pos = dbfile.get_byte
-        elif indextype == "H":
-            self._get_pos = dbfile.get_ushort
-        elif indextype == "i":
-            self._get_pos = dbfile.get_int
-        elif indextype == "I":
-            self._get_pos = dbfile.get_uint
-        elif indextype == "q":
-            self._get_pos = dbfile.get_long
-        else:
-            raise Exception("Unknown index type %r" % indextype)
-
-    def _key_at(self, pos):
-        # Returns the key bytes at the given position
-
-        dbfile = self.dbfile
-        keylen = dbfile.get_uint(pos)
-        return dbfile.get(pos + _lengths.size, keylen)
-
-    def _closest_key_pos(self, key):
-        # Given a key, return the position of that key OR the next highest key
-        # if the given key does not exist
-        if not isinstance(key, bytes_type):
-            raise TypeError("Key %r should be bytes" % key)
-
-        indexbase = self.indexbase
-        indexsize = self.indexsize
-        _key_at = self._key_at
-        _get_pos = self._get_pos
-
-        # Do a binary search of the positions in the index array
-        lo = 0
-        hi = self.indexlen
-        while lo < hi:
-            mid = (lo + hi) // 2
-            midkey = _key_at(_get_pos(indexbase + mid * indexsize))
-            if midkey < key:
-                lo = mid + 1
-            else:
-                hi = mid
-
-        # If we went off the end, return None
-        if lo == self.indexlen:
-            return None
-        # Return the closest key
-        return _get_pos(indexbase + lo * indexsize)
-
+class OrderedBase(HashReader):
     def closest_key(self, key):
         """Returns the closest key equal to or greater than the given key. If
         there is no key in the file equal to or greater than the given key,
         dbfile = self.dbfile
         for keypos, keylen, datapos, datalen in self.ranges_from(key):
             yield (dbfile.get(keypos, keylen), dbfile.get(datapos, datalen))
+
+
+class OrderedHashReader(OrderedBase):
+    def _read_extras(self):
+        dbfile = self.dbfile
+
+        # Read the extras
+        HashReader._read_extras(self)
+
+        # Set up for reading the index array
+        indextype = self.extras["indextype"]
+        self.indexbase = dbfile.tell()
+        self.indexlen = self.extras["indexlen"]
+        self.indexsize = struct.calcsize(indextype)
+        # Set up the function to read values from the index array
+        if indextype == "B":
+            self._get_pos = dbfile.get_byte
+        elif indextype == "H":
+            self._get_pos = dbfile.get_ushort
+        elif indextype == "i":
+            self._get_pos = dbfile.get_int
+        elif indextype == "I":
+            self._get_pos = dbfile.get_uint
+        elif indextype == "q":
+            self._get_pos = dbfile.get_long
+        else:
+            raise Exception("Unknown index type %r" % indextype)
+
+    def _closest_key_pos(self, key):
+        # Given a key, return the position of that key OR the next highest key
+        # if the given key does not exist
+        if not isinstance(key, bytes_type):
+            raise TypeError("Key %r should be bytes" % key)
+
+        indexbase = self.indexbase
+        indexsize = self.indexsize
+        _key_at = self._key_at
+        _get_pos = self._get_pos
+
+        # Do a binary search of the positions in the index array
+        lo = 0
+        hi = self.indexlen
+        while lo < hi:
+            mid = (lo + hi) // 2
+            midkey = _key_at(_get_pos(indexbase + mid * indexsize))
+            if midkey < key:
+                lo = mid + 1
+            else:
+                hi = mid
+
+        # If we went off the end, return None
+        if lo == self.indexlen:
+            return None
+        # Return the closest key
+        return _get_pos(indexbase + lo * indexsize)
+
+
+# Fielded Ordered hash file
+
+class FieldedOrderedHashWriter(HashWriter):
+    """Implements an on-disk hash, but writes separate position indexes for
+    each field.
+    """
+
+    def __init__(self, dbfile):
+        HashWriter.__init__(self, dbfile)
+        # Map field names to (startpos, indexpos, length, typecode)
+        self.fieldmap = self.extras["fieldmap"] = {}
+
+        # Keep track of the last key added
+        self.lastkey = emptybytes
+
+    def start_field(self, fieldname):
+        self.fieldstart = self.dbfile.tell()
+        self.fieldname = fieldname
+        # Keep an array of the positions of all keys
+        self.poses = GrowableArray("H")
+        self.lastkey = emptybytes
+
+    def add(self, key, value):
+        if key <= self.lastkey:
+            raise ValueError("Keys must increase: %r..%r"
+                             % (self.lastkey, key))
+        self.poses.append(self.dbfile.tell() - self.fieldstart)
+        HashWriter.add(self, key, value)
+        self.lastkey = key
+
+    def end_field(self):
+        dbfile = self.dbfile
+        fieldname = self.fieldname
+        poses = self.poses
+        self.fieldmap[fieldname] = (self.fieldstart, dbfile.tell(), len(poses),
+                                    poses.typecode)
+        poses.to_file(dbfile)
+
+
+class FieldedOrderedHashReader(HashReader):
+    def __init__(self, *args, **kwargs):
+        HashReader.__init__(self, *args, **kwargs)
+        self.fieldmap = self.extras["fieldmap"]
+        # Make a sorted list of the field names with their start and end ranges
+        self.fieldlist = []
+        for fieldname in sorted(self.fieldmap.keys()):
+            startpos, ixpos, ixsize, ixtype = self.fieldmap[fieldname]
+            self.fieldlist.append((fieldname, startpos, ixpos))
+
+    def fielded_ranges(self, pos=None, eod=None):
+        flist = self.fieldlist
+        fpos = 0
+        fieldname, start, end = flist[fpos]
+        for keypos, keylen, datapos, datalen in self._ranges(pos, eod):
+            if keypos >= end:
+                fpos += 1
+                fieldname, start, end = flist[fpos]
+            yield fieldname, keypos, keylen, datapos, datalen
+
+    def iter_terms(self):
+        get = self.dbfile.get
+        for fieldname, keypos, keylen, _, _ in self.fielded_ranges():
+            yield fieldname, get(keypos, keylen)
+
+    def iter_term_items(self):
+        get = self.dbfile.get
+        for item in self.fielded_ranges():
+            fieldname, keypos, keylen, datapos, datalen = item
+            yield fieldname, get(keypos, keylen), get(datapos, datalen)
+
+    def contains_term(self, fieldname, btext):
+        try:
+            x = self.range_for_term(fieldname, btext)
+            return True
+        except KeyError:
+            return False
+
+    def range_for_term(self, fieldname, btext):
+        start, ixpos, ixsize, code = self.fieldmap[fieldname]
+        for datapos, datalen in self.ranges_for_key(btext):
+            if start < datapos < ixpos:
+                return datapos, datalen
+        raise KeyError((fieldname, btext))
+
+    def term_data(self, fieldname, btext):
+        datapos, datalen = self.range_for_term(fieldname, btext)
+        return self.dbfile.get(datapos, datalen)
+
+    def term_get(self, fieldname, btext, default=None):
+        try:
+            return self.term_data(fieldname, btext)
+        except KeyError:
+            return default
+
+    def _closest_term_pos(self, fieldname, key):
+        # Given a key, return the position of that key OR the next highest key
+        # if the given key does not exist
+        if not isinstance(key, bytes_type):
+            raise TypeError("Key %r should be bytes" % key)
+
+        dbfile = self.dbfile
+        _key_at = self._key_at
+        startpos, ixpos, ixsize, ixtype = self.fieldmap[fieldname]
+
+        if ixtype == "B":
+            get_pos = dbfile.get_byte
+        elif ixtype == "H":
+            get_pos = dbfile.get_ushort
+        elif ixtype == "i":
+            get_pos = dbfile.get_int
+        elif ixtype == "I":
+            get_pos = dbfile.get_uint
+        elif ixtype == "q":
+            get_pos = dbfile.get_long
+        else:
+            raise Exception("Unknown index type %r" % ixtype)
+
+        # Each entry in the on-disk index array is struct.calcsize(ixtype)
+        # bytes wide; ixsize is the number of entries
+        itemsize = struct.calcsize(ixtype)
+
+        # Do a binary search of the positions in the index array
+        lo = 0
+        hi = ixsize
+        while lo < hi:
+            mid = (lo + hi) // 2
+            midkey = _key_at(startpos + get_pos(ixpos + mid * itemsize))
+            if midkey < key:
+                lo = mid + 1
+            else:
+                hi = mid
+
+        # If we went off the end, return None
+        if lo == ixsize:
+            return None
+        # Return the closest key
+        return startpos + get_pos(ixpos + lo * ixsize)
+
+    def closest_term(self, fieldname, btext):
+        pos = self._closest_term_pos(fieldname, btext)
+        if pos is None:
+            return None
+        return self._key_at(pos)
+
+    def term_ranges_from(self, fieldname, btext):
+        pos = self._closest_term_pos(fieldname, btext)
+        if pos is None:
+            return
+
+        startpos, ixpos, ixsize, ixtype = self.fieldmap[fieldname]
+        for item in self._ranges(pos, ixpos):
+            yield item
+
+    def terms_from(self, fieldname, btext):
+        dbfile = self.dbfile
+        for keypos, keylen, _, _ in self.term_ranges_from(fieldname, btext):
+            yield dbfile.get(keypos, keylen)
+
+    def term_items_from(self, fieldname, btext):
+        dbfile = self.dbfile
+        for item in self.term_ranges_from(fieldname, btext):
+            keypos, keylen, datapos, datalen = item
+            yield (dbfile.get(keypos, keylen), dbfile.get(datapos, datalen))
+
+
+
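
A rough usage sketch for the new fielded classes (again assuming RamStorage and the surrounding HashWriter/HashReader machinery; the file name "terms.hsh", the field names, and the term data are invented, and the comments show expected rather than verified results):

    from whoosh.compat import b
    from whoosh.filedb.filestore import RamStorage
    from whoosh.filedb.filetables import (FieldedOrderedHashReader,
                                          FieldedOrderedHashWriter)

    st = RamStorage()

    # Keys are added per field with start_field()/end_field(); within a
    # field they must be added in increasing order
    hw = FieldedOrderedHashWriter(st.create_file("terms.hsh"))
    hw.start_field("content")
    hw.add(b("alfa"), b("1"))
    hw.add(b("bravo"), b("2"))
    hw.end_field()
    hw.start_field("title")
    hw.add(b("charlie"), b("3"))
    hw.end_field()
    hw.close()

    hr = FieldedOrderedHashReader(st.open_file("terms.hsh"))
    print(hr.contains_term("content", b("alfa")))  # expected: True
    print(hr.term_data("title", b("charlie")))     # expected: b"3"
    print(list(hr.terms_from("content", b("b"))))  # keys >= b"b" in "content"
    hr.close()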