Matt Chaput committed 11d0f6f

Switched the hash file to a new, more flexible format, switched back to cdb_hash, and added backwards compatibility.

Fixes issue #131.
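
For reference, a minimal sketch of the two on-disk header layouts implied by this change (the struct formats are copied from the diff below; the NEW_*/OLD_* names are illustrative, not from the source):

    from struct import Struct

    _header_entry = Struct("!qI")  # (position, number of slots) per bucket

    # New format (this commit): a 4-byte "HASH" magic, an unused 4-byte int,
    # an 8-byte pointer to the end of the hash tables, then the 256-bucket
    # directory.
    NEW_HEADER_SIZE = 4 + 4 + 8 + 256 * _header_entry.size  # 3088 bytes
    NEW_POINTER = Struct("!Iq")  # (32-bit cdb_hash value, position)

    # Old format: no magic, just the 256-bucket directory; slots hold the
    # 64-bit value of Python's builtin hash().
    OLD_HEADER_SIZE = 256 * _header_entry.size  # 3072 bytes
    OLD_POINTER = Struct("!qq")  # (64-bit hash value, position)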


Files changed (1)

src/whoosh/filedb/filetables.py
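
The central change: keys are now hashed with cdb_hash, the classic 32-bit DJB/cdb string hash, instead of Python's builtin hash(), whose result differs between 32-bit and 64-bit interpreters (a plausible motivation for the switch, though the commit message doesn't say). A quick illustration, with values following directly from the definition restored in the diff below:

    # cdb_hash starts from 5381 and folds each byte in as
    # h = ((h << 5) + h) ^ ord(c), truncated to 32 bits:
    assert cdb_hash("") == 5381
    assert cdb_hash("a") == ((5381 * 33) & 0xffffffff) ^ ord("a")  # 177604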

 
 _4GB = 4 * 1024 * 1024 * 1024
 
-#def cdb_hash(key):
-#    h = 5381L
-#    for c in key:
-#        h = (h + (h << 5)) & 0xffffffffL ^ ord(c)
-#    return h
+def cdb_hash(key):
+    h = 5381L
+    for c in key:
+        h = (h + (h << 5)) & 0xffffffffL ^ ord(c)
+    return h
 
 _header_entry_struct = Struct("!qI")  # Position, number of slots
 header_entry_size = _header_entry_struct.size
 pack_lengths = _lengths_struct.pack
 unpack_lengths = _lengths_struct.unpack
 
-_pointer_struct = Struct("!qq")  # Hash value, position
-pointer_size = _pointer_struct.size
-pack_pointer = _pointer_struct.pack
-unpack_pointer = _pointer_struct.unpack
-
-HEADER_SIZE = 256 * header_entry_size
-
-#def _hash(value):
-#    return abs(hash(value))
-_hash = hash
-
 
 # Table classes
 
 class HashWriter(object):
-    def __init__(self, dbfile):
+    def __init__(self, dbfile, old_format=False):
         self.dbfile = dbfile
-        # Seek past the first 2048 bytes of the file... we'll come back here
-        # to write the header later
-        dbfile.seek(HEADER_SIZE)
+        self.old_format = old_format
+        
+        if not old_format:
+            self.header_size = 16 + 256 * header_entry_size
+            _pointer_struct = Struct("!Iq")  # Hash value, position
+            self.hash_func = cdb_hash
+        else:
+            self.header_size = 256 * header_entry_size
+            _pointer_struct = Struct("!qq")  # Hash value, position
+            self.hash_func = hash
+        
+        self.pointer_size = _pointer_struct.size
+        self.pack_pointer = _pointer_struct.pack
+        
+        # Seek past the first "header_size" bytes of the file... we'll come
+        # back here to write the header later
+        dbfile.seek(self.header_size)
         # Store the directory of hashed values
         self.hashes = defaultdict(list)
 
     def add_all(self, items):
         dbfile = self.dbfile
+        hash_func = self.hash_func
         hashes = self.hashes
         pos = dbfile.tell()
         write = dbfile.write
             write(key)
             write(value)
 
-            h = _hash(key)
+            h = hash_func(key)
             hashes[h & 255].append((h, pos))
             pos += lengths_size + len(key) + len(value)
 
 
             write = dbfile.write
             for hashval, position in hashtable:
-                write(pack_pointer(hashval, position))
-                pos += pointer_size
+                write(self.pack_pointer(hashval, position))
+                pos += self.pointer_size
 
         dbfile.flush()
+        self._end_of_hashes = dbfile.tell()
 
     def _write_directory(self):
         dbfile = self.dbfile
         directory = self.directory
 
         dbfile.seek(0)
+        if not self.old_format:
+            dbfile.write("HASH")
+            dbfile.write_int(0)  # Unused
+            dbfile.write_long(self._end_of_hashes)
+        
         for position, numslots in directory:
             dbfile.write(pack_header_entry(position, numslots))
-        assert dbfile.tell() == HEADER_SIZE
+        
         dbfile.flush()
+        assert dbfile.tell() == self.header_size
 
     def close(self):
         self._write_hashes()
     def __init__(self, dbfile):
         self.dbfile = dbfile
         self.map = dbfile.map
-        self.end_of_data = dbfile.get_long(0)
+        
+        dbfile.seek(0)
+        magic = dbfile.read(4)
+        if magic == "HASH":
+            self.old_format = False
+            self.header_size = 16 + 256 * header_entry_size
+            _pointer_struct = Struct("!Iq")  # Hash value, position
+            dbfile.read_int()  # Unused
+            self._end_of_hashes = dbfile.read_long()
+            assert self._end_of_hashes >= self.header_size, "%s < %s" % (self._end_of_hashes, self.header_size)
+            self.hash_func = cdb_hash
+        else:
+            self.old_format = True
+            self.header_size = 256 * header_entry_size
+            _pointer_struct = Struct("!qq")  # Hash value, position
+            self.hash_func = hash
+        
+        self.buckets = []
+        for _ in xrange(256):
+            he = unpack_header_entry(dbfile.read(header_entry_size))
+            self.buckets.append(he)
+        self._start_of_hashes = self.buckets[0][0]
+        
+        self.pointer_size = _pointer_struct.size
+        self.unpack_pointer = _pointer_struct.unpack
+
         self.is_closed = False
 
     def close(self):
     def read(self, position, length):
         return self.map[position:position + length]
 
-    def _ranges(self, pos=HEADER_SIZE):
-        eod = self.end_of_data
+    def _ranges(self, pos=None):
+        if pos is None:
+            pos = self.header_size
+        eod = self._start_of_hashes
         read = self.read
         while pos < eod:
             keylen, datalen = unpack_lengths(read(pos, lengths_size))
 
     def _hashtable_info(self, keyhash):
         # Return (directory_position, number_of_hash_entries)
-        return unpack_header_entry(self.read((keyhash & 255) * header_entry_size,
-                                             header_entry_size))
+        return self.buckets[keyhash & 255]
 
     def _key_position(self, key):
-        keyhash = _hash(key)
+        keyhash = self.hash_func(key)
         hpos, hslots = self._hashtable_info(keyhash)
         if not hslots:
             raise KeyError(key)
 
     def _get_ranges(self, key):
         read = self.read
-        keyhash = _hash(key)
+        pointer_size = self.pointer_size
+        keyhash = self.hash_func(key)
         hpos, hslots = self._hashtable_info(keyhash)
         if not hslots:
             return
 
         slotpos = hpos + (((keyhash >> 8) % hslots) * pointer_size)
         for _ in xrange(hslots):
-            slothash, pos = unpack_pointer(read(slotpos, pointer_size))
+            slothash, pos = self.unpack_pointer(read(slotpos, pointer_size))
             if not pos:
                 return
 
                         yield (pos + lengths_size + keylen, datalen)
                         
     def end_of_hashes(self):
-        lastpos, lastnum = unpack_header_entry(self.read(255 * header_entry_size,
-                                                         header_entry_size))
-        return lastpos + lastnum * pointer_size
+        if self.old_format:
+            lastpos, lastnum = self.buckets[255]
+            return lastpos + lastnum * self.pointer_size
+        else:
+            return self._end_of_hashes
 
 
 class OrderedHashWriter(HashWriter):
     def add_all(self, items):
         dbfile = self.dbfile
         hashes = self.hashes
+        hash_func = self.hash_func
         pos = dbfile.tell()
         write = dbfile.write
 
             write(key)
             write(value)
 
-            h = _hash(key)
+            h = hash_func(key)
             hashes[h & 255].append((h, pos))
             
             pos += lengths_size + len(key) + len(value)
 
 # Utility functions
 
-def dump_hash(hashreader):
-    dbfile = hashreader.dbfile
-    read = hashreader.read
-    eod = hashreader.end_of_data
+#def dump_hash(hashreader):
+#    dbfile = hashreader.dbfile
+#    read = hashreader.read
+#    eod = hashreader._start_of_hashes
+#
+#    print "HEADER_SIZE=", hashreader.header_size, "eod=", eod
+#
+#    # Dump hashtables
+#    for bucketnum in xrange(256):
+#        pos, numslots = unpack_header_entry(read(bucketnum * header_entry_size, header_entry_size))
+#        if numslots:
+#            print "Bucket %d: %d slots" % (bucketnum, numslots)
+#
+#            dbfile.seek(pos)
+#            for _ in xrange(0, numslots):
+#                print "  %X : %d" % hashreader.unpack_pointer(read(pos, pointer_size))
+#                pos += pointer_size
+#        else:
+#            print "Bucket %d empty: %s, %s" % (bucketnum, pos, numslots)
+#
+#    # Dump keys and values
+#    print "-----"
+#    pos = hashreader.header_size
+#    dbfile.seek(pos)
+#    while pos < eod:
+#        keylen, datalen = unpack_lengths(read(pos, lengths_size))
+#        keypos = pos + lengths_size
+#        datapos = pos + lengths_size + keylen
+#        key = read(keypos, keylen)
+#        data = read(datapos, datalen)
+#        print "%d +%d,%d:%r->%r" % (pos, keylen, datalen, key, data)
+#        pos = datapos + datalen
 
-    print "HEADER_SIZE=", HEADER_SIZE, "eod=", eod
 
-    # Dump hashtables
-    for bucketnum in xrange(0, 256):
-        pos, numslots = unpack_header_entry(read(bucketnum * header_entry_size, header_entry_size))
-        if numslots:
-            print "Bucket %d: %d slots" % (bucketnum, numslots)
-
-            dbfile.seek(pos)
-            for _ in xrange(0, numslots):
-                print "  %X : %d" % unpack_pointer(read(pos, pointer_size))
-                pos += pointer_size
-        else:
-            print "Bucket %d empty: %s, %s" % (bucketnum, pos, numslots)
-
-    # Dump keys and values
-    print "-----"
-    pos = HEADER_SIZE
-    dbfile.seek(pos)
-    while pos < eod:
-        keylen, datalen = unpack_lengths(read(pos, lengths_size))
-        keypos = pos + lengths_size
-        datapos = pos + lengths_size + keylen
-        key = read(keypos, keylen)
-        data = read(datapos, datalen)
-        print "%d +%d,%d:%r->%r" % (pos, keylen, datalen, key, data)
-        pos = datapos + datalen
-
-
-##
-#
-#class FixedHashWriter(HashWriter):
-#    def __init__(self, dbfile, keysize, datasize):
-#        self.dbfile = dbfile
-#        dbfile.seek(HEADER_SIZE)
-#        self.hashes = defaultdict(list)
-#        self.keysize = keysize
-#        self.datasize = datasize
-#        self.recordsize = keysize + datasize
-#
-#    def add_all(self, items):
-#        dbfile = self.dbfile
-#        hashes = self.hashes
-#        recordsize = self.recordsize
-#        pos = dbfile.tell()
-#        write = dbfile.write
-#
-#        for key, value in items:
-#            write(key + value)
-#
-#            h = _hash(key)
-#            hashes[h & 255].append((h, pos))
-#            pos += recordsize
-#
-#
-#class FixedHashReader(HashReader):
-#    def __init__(self, dbfile, keysize, datasize):
-#        self.dbfile = dbfile
-#        self.keysize = keysize
-#        self.datasize = datasize
-#        self.recordsize = keysize + datasize
-#        
-#        self.map = dbfile.map
-#        self.end_of_data = dbfile.get_uint(0)
-#        self.is_closed = False
-#
-#    def read(self, position, length):
-#        return self.map[position:position + length]
-#
-#    def _ranges(self, pos=HEADER_SIZE):
-#        keysize = self.keysize
-#        recordsize = self.recordsize
-#        eod = self.end_of_data
-#        while pos < eod:
-#            yield (pos, pos + keysize)
-#            pos += recordsize
-#
-#    def __iter__(self):
-#        return self.items()
-#
-#    def __contains__(self, key):
-#        for _ in self._get_data_poses(key):
-#            return True
-#        return False
-#
-#    def items(self):
-#        keysize = self.keysize
-#        datasize = self.datasize
-#        read = self.read
-#        for keypos, datapos in self._ranges():
-#            yield (read(keypos, keysize), read(datapos, datasize))
-#
-#    def keys(self):
-#        keysize = self.keysize
-#        read = self.read
-#        for keypos, _ in self._ranges():
-#            yield read(keypos, keysize)
-#
-#    def values(self):
-#        datasize = self.datasize
-#        read = self.read
-#        for _, datapos in self._ranges():
-#            yield read(datapos, datasize)
-#
-#    def __getitem__(self, key):
-#        for data in self.all(key):
-#            return data
-#        raise KeyError(key)
-#
-#    def get(self, key, default=None):
-#        for data in self.all(key):
-#            return data
-#        return default
-#
-#    def all(self, key):
-#        datasize = self.datasize
-#        read = self.read
-#        for datapos in self._get_data_poses(key):
-#            yield read(datapos, datasize)
-#
-#    def _key_at(self, pos):
-#        return self.read(pos, self.keysize)
-#
-#    def _get_ranges(self, key):
-#        raise NotImplementedError
-#
-#    def _get_data_poses(self, key):
-#        keysize = self.keysize
-#        read = self.read
-#        keyhash = _hash(key)
-#        hpos, hslots = self._hashtable_info(keyhash)
-#        if not hslots:
-#            return
-#
-#        slotpos = hpos + (((keyhash >> 8) % hslots) * pointer_size)
-#        for _ in xrange(hslots):
-#            slothash, pos = unpack_pointer(read(slotpos, pointer_size))
-#            if not pos:
-#                return
-#
-#            slotpos += pointer_size
-#            # If we reach the end of the hashtable, wrap around
-#            if slotpos == hpos + (hslots * pointer_size):
-#                slotpos = hpos
-#
-#            if slothash == keyhash:
-#                if key == read(pos, keysize):
-#                    yield pos + keysize
-
-