#===============================================================================
# Copyright 2008 Matt Chaput
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#    http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#===============================================================================

"""
Generic storage classes for creating static files that support fast
key-value storage (the Table* classes), with optional per-key posting data.

These objects require that you add rows in increasing order of their
keys. They will raise an exception if you try to add keys out of order.

These objects use a simple file format. The first 4 bytes are an unsigned
long ("!L" struct) pointing to the directory data.
The next 4 bytes are a pointer to the posting data, if any. In a table without
postings, this is 0.
Following that are N pickled objects (the blocks of rows).
Following the blocks is the directory: a pickled list of the block keys, an
array of block file positions, and a pickled dictionary of table options.
Because the keys are pickled as part of the directory,
they can be any pickle-able object. (The keys must also be hashable because
they are used as dictionary keys. It's best to use value types for the
keys: tuples, numbers, and/or strings.)

This module also contains simple implementations for writing and reading
static "Record" files made up of fixed-length records based on the
struct module.
"""

import shutil, tempfile
from array import array
from bisect import bisect_left, bisect_right
from marshal import loads, dumps

try:
    from zlib import compress, decompress
    has_zlib = True
except ImportError:
    has_zlib = False

from whoosh.structfile import _USHORT_SIZE, StructFile
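
# A minimal usage sketch (added for illustration; not part of the original
# module). It uses only names defined in this module and in whoosh.structfile;
# the file name "example.tbl" is hypothetical.
#
#     from whoosh.structfile import StructFile
#     from whoosh.tables import TableWriter, TableReader
#
#     writer = TableWriter(StructFile(open("example.tbl", "wb")))
#     writer.add_row("alfa", [1, 2, 3])   # keys must be added in increasing order
#     writer.add_row("bravo", [4, 5])
#     writer.close()
#
#     reader = TableReader(StructFile(open("example.tbl", "rb")))
#     print reader.get("bravo")           # -> [4, 5]
#     reader.close()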

# Utility functions

def copy_data(treader, inkey, twriter, outkey, postings = False, buffersize = 32 * 1024):
    """
    Copies the data associated with inkey in the "reader" table (treader)
    to outkey in the "writer" table (twriter), along with the raw posting
    data if postings is True.
    """
    
    if postings:
        offset, length, postcount, data = treader._get_plain(inkey)
        twriter.add_row(outkey, data,
                        postinginfo=(twriter.offset, length, postcount))
        
        # Copy the raw posting data
        infile = treader.table_file
        infile.seek(treader.postpos + offset)
        outfile = twriter.posting_file
        if length <= buffersize:
            outfile.write(infile.read(length))
        else:
            sofar = 0
            while sofar < length:
                readsize = min(buffersize, length - sofar)
                outfile.write(infile.read(readsize))
                sofar += readsize
        
        twriter.offset = outfile.tell()
    else:
        twriter.add_row(outkey, treader[inkey])


# Table writer classes

class TableWriter(object):
    def __init__(self, table_file, blocksize = 16 * 1024,
                 compressed = 0, prefixcoding = False,
                 postings = False, stringids = False,
                 checksize = True):
        self.table_file = table_file
        self.blocksize = blocksize
        
        if compressed > 0 and not has_zlib:
            raise Exception("zlib is not available: cannot compress table")
        self.compressed = compressed
        self.prefixcoding = prefixcoding
        
        self.haspostings = postings
        if postings:
            self.offset = 0
            self.postcount = 0
            self.lastpostid = None
            self.stringids = stringids
            self.posting_file = StructFile(tempfile.TemporaryFile())
        
        self.rowbuffer = []
        self.lastkey = None
        self.blockfilled = 0
        
        self.keys = []
        self.pointers = array("L")
        
        # Remember where we started writing
        self.start = table_file.tell()
        # Save space for a pointer to the directory
        table_file.write_ulong(0)
        # Save space for a pointer to the postings
        table_file.write_ulong(0)
        
        self.options = {"haspostings": postings,
                        "compressed": compressed,
                        "prefixcoding": prefixcoding,
                        "stringids": stringids}
    
    def close(self):
        # If there is still a block waiting to be written, flush it out
        if self.rowbuffer:
            self._write_block()
        
        tf = self.table_file
        haspostings = self.haspostings
        
        # Remember where we started writing the directory
        dirpos = tf.tell()
        # Write the directory
        tf.write_pickle(self.keys)
        tf.write_array(self.pointers)
        tf.write_pickle(self.options)
        
        if haspostings:
            # Remember where we started the postings
            postpos = tf.tell()
            # Seek back to the beginning of the postings and
            # copy them onto the end of the table file.
            self.posting_file.seek(0)
            shutil.copyfileobj(self.posting_file, tf)
            self.posting_file.close()
        
        # Seek back to where we started writing and write a
        # pointer to the directory
        tf.seek(self.start)
        tf.write_ulong(dirpos)
        
        if haspostings:
            # Write a pointer to the postings
            tf.write_ulong(postpos)
        
        tf.close()
    
    def _write_block(self):
        buf = self.rowbuffer
        key = buf[0][0]
        compressed = self.compressed
        
        self.keys.append(key)
        self.pointers.append(self.table_file.tell())
        if compressed:
            pck = dumps(buf)
            self.table_file.write_string(compress(pck, compressed))
        else:
            self.table_file.write_pickle(buf)
        
        self.rowbuffer = []
        self.blockfilled = 0
    
    def write_posting(self, id, data, writefn):
        # IDs must be added in increasing order
        if id <= self.lastpostid:
            raise IndexError("IDs must increase: %r..%r" % (self.lastpostid, id))
        
        pf = self.posting_file
        if self.stringids:
            pf.write_string(id.encode("utf8"))
        else:
            lastpostid = self.lastpostid or 0
            pf.write_varint(id - lastpostid)
        
        self.lastpostid = id
        self.postcount += 1
        
        return writefn(pf, data)
    
    def add_row(self, key, data, postinginfo=None):
        # Note: call this AFTER you add any postings!
        # Keys must be added in increasing order
        if key <= self.lastkey:
            raise IndexError("Keys must increase: %r..%r" % (self.lastkey, key))
        
        rb = self.rowbuffer
        
        if isinstance(data, array):
            self.blockfilled += len(data) * data.itemsize
        else:
            # Ugh! We're serializing twice (marshal here just to estimate the
            # size, pickle again when the block is written). At least it's fast.
            self.blockfilled += len(dumps(data))
        self.lastkey = key
        
        if self.haspostings:
            endoffset = self.posting_file.tell()
            
            # Add the posting info to the stored row data
            
            # The postinginfo keyword argument allows us to copy
            # information about postings from another table.
            if postinginfo:
                offset, length, postcount = postinginfo
            else:
                offset = self.offset
                length = endoffset - self.offset
                postcount = self.postcount
            rb.append((key, (offset, length, postcount, data)))
            
            # Reset the posting variables
            self.offset = endoffset
            self.postcount = 0
            self.lastpostid = None
        else:
            rb.append((key, data))
        
        # If this row filled up a block, flush it out
        if self.blockfilled >= self.blocksize:
            #print len(rb)
            self._write_block()


# Table reader classes

class TableReader(object):
    def __init__(self, table_file):
        self.table_file = table_file
        
        # Read the pointer to the directory
        dirpos = table_file.read_ulong()
        # Read the pointer to the postings (0 if there are no postings)
        self.postpos = table_file.read_ulong()
        
        # Seek to where the directory begins and read it
        table_file.seek(dirpos)
        self.blockindex = table_file.read_pickle()
        self.blockcount = len(self.blockindex)
        self.blockpositions = table_file.read_array("L", self.blockcount)
        options = table_file.read_pickle()
        self.__dict__.update(options)
        
        if self.compressed > 0 and not has_zlib:
            raise Exception("zlib is not available: cannot decompress table")
        
        # Initialize cached block
        self.currentblock = None
        self.itemlist = None
        self.itemdict = None
        
        if self.haspostings:
            if self.stringids:
                self._read_id = self._read_id_string
            else:
                self._read_id = self._read_id_varint
            self.get = self._get_ignore_postinfo
        else:
            self.get = self._get_plain
    
    def __contains__(self, key):
        if key < self.blockindex[0]:
            return False
        self._load_block(key)
        return key in self.itemdict
    
    def _get_ignore_postinfo(self, key):
        self._load_block(key)
        return self.itemdict[key][3]
    
    def _get_plain(self, key):
        self._load_block(key)
        return self.itemdict[key]
    
    def __iter__(self):
        if self.haspostings:
            for i in xrange(0, self.blockcount):
                self._load_block_num(i)
                for key, value in self.itemlist:
                    yield (key, value[3])
        else:
            for i in xrange(0, self.blockcount):
                self._load_block_num(i)
                for key, value in self.itemlist:
                    yield (key, value)
    
    def _read_id_varint(self, lastid):
        return lastid + self.table_file.read_varint()
    
    def _read_id_string(self, lastid):
        return self.table_file.read_string().decode("utf8")
    
    def iter_from(self, key):
        postings = self.haspostings
        
        self._load_block(key)
        blockcount = self.blockcount
        itemlist = self.itemlist
        
        p = bisect_left(itemlist, (key, None))
        if p >= len(itemlist):
            if self.currentblock >= blockcount - 1:
                return
            self._load_block_num(self.currentblock + 1)
            itemlist = self.itemlist
            p = 0
        
        # Yield the rest of the rows
        while True:
            kv = itemlist[p]
            if postings:
                yield (kv[0], kv[1][3])
            else:
                yield kv
            
            p += 1
            if p >= len(itemlist):
                if self.currentblock >= blockcount - 1:
                    return
                self._load_block_num(self.currentblock + 1)
                itemlist = self.itemlist
                p = 0
    
    def close(self):
        self.table_file.close()
    
    def keys(self):
        return (key for key, _ in self)
    
    def values(self):
        return (value for _, value in self)
    
    def posting_count(self, key):
        if not self.haspostings: raise Exception("This table does not have postings")
        return self._get_plain(key)[2]
    
    def postings(self, key, readfn):
        postfile = self.table_file
        _read_id = self._read_id
        id = 0
        for _ in xrange(0, self._seek_postings(key)):
            id = _read_id(id)
            yield (id, readfn(postfile))
    
    def _load_block_num(self, bn):
        blockcount = len(self.blockindex)
        if bn < 0 or bn >= blockcount:
            raise ValueError("Block number %s/%s" % (bn, blockcount))
        
        pos = self.blockpositions[bn]
        self.table_file.seek(pos)
        
        # Sooooooo sloooooow...
        if self.compressed:
            pck = self.table_file.read_string()
            itemlist = loads(decompress(pck))
        else:
            itemlist = self.table_file.read_pickle()
        
        self.itemlist = itemlist
        self.itemdict = dict(itemlist)
        self.currentblock = bn
        self.minkey = itemlist[0][0]
        self.maxkey = itemlist[-1][0]
    
    def _load_block(self, key):
        if self.currentblock is None or key < self.minkey or key > self.maxkey:
            bn = max(0, bisect_right(self.blockindex, key) - 1)
            self._load_block_num(bn)

    def _seek_postings(self, key):
        offset, length, count = self._get_plain(key)[:3] #@UnusedVariable
        self.table_file.seek(self.postpos + offset)
        return count
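

# A sketch of the posting protocol (an illustration, not part of the original
# docs): postings for a key are written BEFORE add_row() is called for that
# key, and the writefn/readfn callables decide how each posting's payload is
# stored. The file name and the varint payload used here are arbitrary.
#
#     w = TableWriter(StructFile(open("post.tbl", "wb")), postings=True)
#     for docnum, freq in [(1, 2), (5, 1), (7, 3)]:
#         w.write_posting(docnum, freq, lambda pf, data: pf.write_varint(data))
#     w.add_row("term", None)   # row data can be any picklable value
#     w.close()
#
#     r = TableReader(StructFile(open("post.tbl", "rb")))
#     print list(r.postings("term", lambda pf: pf.read_varint()))
#     # -> [(1, 2), (5, 1), (7, 3)]
#     r.close()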


# An array table only stores numeric arrays and does not support postings.
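#
# A usage sketch (illustration only; "arr.bin" and the "f" typecode are
# arbitrary choices and assume the underlying StructFile supports that
# typecode):
#
#     w = ArrayWriter(StructFile(open("arr.bin", "wb")), "f")
#     w.add_row("lengths", [1.0, 0.5, 0.25])
#     w.close()
#
#     r = ArrayReader(StructFile(open("arr.bin", "rb")))
#     print r.get("lengths", 2)   # -> 0.25
#     r.close()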

class ArrayWriter(object):
    def __init__(self, table_file, typecode, bufferlength=4*1024):
        if typecode not in table_file._type_writers:
            raise Exception("Can't (yet) write an array table of type %r" % typecode)
        
        self.table_file = table_file
        self.typecode = typecode
        self.bufferlength = bufferlength
        self.dir = {}
        self.buffer = array(typecode)
        
        # Remember where we started writing
        self.start = table_file.tell()
        # Save space for a pointer to the directory
        table_file.write_ulong(0)
    
    def _flush(self):
        buff = self.buffer
        if buff:
            self.table_file.write_array(buff)
        self.buffer = array(self.typecode)
    
    def close(self):
        self._flush()
        tf = self.table_file
        
        # Remember where we started writing the directory
        dirpos = tf.tell()
        # Write the directory
        tf.write_pickle((self.typecode, self.dir))
        
        # Seek back to where we started writing and write a
        # pointer to the directory
        tf.seek(self.start)
        tf.write_ulong(dirpos)
        
        tf.close()
        
    def add_row(self, key, values = None):
        self._flush()
        self.dir[key] = self.table_file.tell()
        if values:
            self.extend(values)
        
    def append(self, value):
        buff = self.buffer
        buff.append(value)
        if len(buff) > self.bufferlength:
            self._flush()
            
    def extend(self, values):
        buff = self.buffer
        buff.extend(values)
        if len(buff) > self.bufferlength:
            self._flush()
            
    def from_file(self, fobj):
        self._flush()
        shutil.copyfileobj(fobj, self.table_file)


class ArrayReader(object):
    def __init__(self, table_file):
        self.table_file = table_file
        
        # Read the pointer to the directory
        dirpos = table_file.read_ulong()
        # Seek to where the directory begins and read it
        table_file.seek(dirpos)
        typecode, self.dir = table_file.read_pickle()
        
        # Set the "read()" method of this object to the appropriate
        # read method of the underlying StructFile for the table's
        # data type.
        try:
            self.read = self.table_file._type_readers[typecode]
        except KeyError:
            raise Exception("Can't (yet) read an array table of type %r" % typecode)
        
        self.typecode = typecode
        self.itemsize = array(typecode).itemsize
    
    def __contains__(self, key):
        return key in self.dir
    
    def get(self, key, offset):
        tf = self.table_file
        pos = self.dir[key]
        tf.seek(pos + offset * self.itemsize)
        return self.read()
    
    def close(self):
        self.table_file.close()
        
    def to_file(self, key, fobj):
        raise NotImplementedError
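

# Record files (see the module docstring) hold fixed-length records of a
# single array typecode. A usage sketch (illustration only; "recs.bin" and
# the "i" typecode are arbitrary choices and assume the underlying StructFile
# can read that typecode):
#
#     from array import array
#
#     w = RecordWriter(StructFile(open("recs.bin", "wb")), "i", 3)
#     w.append(array("i", [1, 2, 3]))
#     w.append(array("i", [4, 5, 6]))
#     w.close()
#
#     r = RecordReader(StructFile(open("recs.bin", "rb")))
#     print r.get(1, 2)   # -> 6 (item 2 of record 1)
#     r.close()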


class RecordWriter(object):
    def __init__(self, table_file, typecode, length):
        self.table_file = table_file
        self.typecode = typecode
        self.length = length
        
        table_file.write(typecode[0])
        table_file.write_ushort(length)
    
    def close(self):
        self.table_file.close()
        
    def append(self, arry):
        assert arry.typecode == self.typecode
        assert len(arry) == self.length
        self.table_file.write_array(arry)
        

class RecordReader(object):
    def __init__(self, table_file):
        self.table_file = table_file
        self.typecode = table_file.read(1)
        
        try:
            self.read = self.table_file._type_readers[self.typecode]
        except KeyError:
            raise Exception("Can't (yet) read a record file of type %r" % self.typecode)
        
        self.length = table_file.read_ushort()
        self.itemsize = array(self.typecode).itemsize
        self.recordsize = self.length * self.itemsize
    
    def close(self):
        self.table_file.close()
    
    def get(self, recordnum, itemnum):
        assert itemnum < self.length
        self.table_file.seek(1 + _USHORT_SIZE
                             + recordnum * self.recordsize
                             + itemnum * self.itemsize)
        return self.read()
    
    def get_record(self, recordnum):
        tf = self.table_file
        tf.seek(1 + _USHORT_SIZE + recordnum * self.recordsize)
        return tf.read_array(self.typecode, self.length)


class StringListWriter(object):
    def __init__(self, table_file, listlength):
        self.table_file = table_file
        self.listlength = listlength
        self.positions = array("L")
        
        table_file.write_ulong(0)
    
    def close(self):
        tf = self.table_file
        directory_pos = tf.tell()
        tf.write_array(self.positions)
        tf.seek(0)
        tf.write_ulong(directory_pos)
        tf.close()
    
    def append(self, ustrings):
        assert len(ustrings) == self.listlength
        tf = self.table_file
        
        self.positions.append(tf.tell())
        
        encoded = [ustring.encode("utf8") for ustring in ustrings]
        lenarray = array("I", (len(s) for s in encoded))
        tf.write_array(lenarray)
        tf.write("".join(encoded))
        

class StringListReader(object):
    def __init__(self, table_file, listlength, size):
        self.table_file = table_file
        self.listlength = listlength
        self.size = size
        
        self.positions = table_file.read_array("L", size)
        
    def close(self):
        self.table_file.close()
    
    def get(self, num):
        tf = self.table_file
        listlength = self.listlength
        
        tf.seek(self.positions[num])
        lens = tf.read_array("I", listlength)
        string = tf.read(sum(lens))
        
        p = 0
        decoded = []
        for ln in lens:
            decoded.append(string[p:p+ln].decode("utf8"))
            p += ln
        return decoded


if __name__ == '__main__':
    pass