Source

lupyne / lupyne / engine / queries.py

"""
Query wrappers and search utilities.
"""

from future_builtins import filter, map
import itertools
import bisect
import heapq
import threading
import lucene
try:
    from java.lang import Integer
    from java.util import Arrays, HashSet
    from org.apache.lucene import document, index, search, util
    from org.apache.lucene.search import highlight, spans, vectorhighlight
    from org.apache.pylucene import search as search_
    from org.apache.pylucene.queryParser import PythonQueryParser
except ImportError:
    from lucene import Integer, Arrays, HashSet, PythonQueryParser
    document = index = search = util = highlight = spans = vectorhighlight = search_ = lucene

class Query(object):
    """Inherited lucene Query, with dynamic base class acquisition.
    Uses class methods and operator overloading for convenient query construction.
    """
    def __new__(cls, base, *args):
        return base.__new__(type(base.__name__, (cls, base), {}))
    def __init__(self, base, *args):
        base.__init__(self, *args)
    def filter(self, cache=True):
        "Return lucene CachingWrapperFilter, optionally just QueryWrapperFilter."
        if isinstance(self, search.PrefixQuery):
            filter = search.PrefixFilter(self.getPrefix())
        elif isinstance(self, search.TermRangeQuery):
            filter = search.TermRangeFilter(self.field, self.lowerTerm, self.upperTerm, self.includesLower(), self.includesUpper())
        elif isinstance(self, search.TermQuery):
            filter = search.TermsFilter()
            filter.addTerm(self.getTerm())
        else:
            filter = search.QueryWrapperFilter(self)
        return search.CachingWrapperFilter(filter) if cache else filter
    def terms(self):
        "Generate set of query term items."
        terms = HashSet().of_(index.Term)
        self.extractTerms(terms)
        for term in terms:
            yield term.field(), term.text()
    @classmethod
    def term(cls, name, value, boost=1.0):
        "Return lucene TermQuery."
        self = cls(search.TermQuery, index.Term(name, value))
        self.boost = boost
        return self
    @classmethod
    def boolean(cls, occur, *queries, **terms):
        "Return `BooleanQuery`_ with the occur flag applied to all of the queries and terms."
        self = BooleanQuery(search.BooleanQuery)
        for query in queries:
            self.add(query, occur)
        for name, values in terms.items():
            for value in ([values] if isinstance(values, basestring) else values):
                self.add(cls.term(name, value), occur)
        return self
    @classmethod
    def any(cls, *queries, **terms):
        "Return `BooleanQuery`_ (OR) from queries and terms."
        return cls.boolean(search.BooleanClause.Occur.SHOULD, *queries, **terms)
    @classmethod
    def all(cls, *queries, **terms):
        "Return `BooleanQuery`_ (AND) from queries and terms."
        return cls.boolean(search.BooleanClause.Occur.MUST, *queries, **terms)
    @classmethod
    def disjunct(cls, multiplier, *queries, **terms):
        "Return lucene DisjunctionMaxQuery from queries and terms."
        self = cls(search.DisjunctionMaxQuery, Arrays.asList(queries), multiplier)
        for name, values in terms.items():
            for value in ([values] if isinstance(values, basestring) else values):
                self.add(cls.term(name, value))
        return self
    @classmethod
    def span(cls, *term):
        "Return `SpanQuery`_ from term name and value or a MultiTermQuery."
        if len(term) <= 1:
            return SpanQuery(spans.SpanMultiTermQueryWrapper, *term)
        return SpanQuery(spans.SpanTermQuery, index.Term(*term))
    @classmethod
    def near(cls, name, *values, **kwargs):
        """Return :meth:`SpanNearQuery <SpanQuery.near>` from terms.
        Term values which supply another field name will be masked."""
        spans = (cls.span(name, value) if isinstance(value, basestring) else cls.span(*value).mask(name) for value in values)
        return SpanQuery.near(*spans, **kwargs)
    @classmethod
    def prefix(cls, name, value):
        "Return lucene PrefixQuery."
        return cls(search.PrefixQuery, index.Term(name, value))
    @classmethod
    def range(cls, name, start, stop, lower=True, upper=False):
        "Return lucene RangeQuery, by default with a half-open interval."
        return cls(search.TermRangeQuery, name, start, stop, lower, upper)
    @classmethod
    def phrase(cls, name, *values):
        "Return lucene PhraseQuery.  None may be used as a placeholder."
        self = cls(search.PhraseQuery)
        for idx, value in enumerate(values):
            if value is not None:
                self.add(index.Term(name, value), idx)
        return self
    @classmethod
    def multiphrase(cls, name, *values):
        "Return lucene MultiPhraseQuery.  None may be used as a placeholder."
        self = cls(search.MultiPhraseQuery)
        for idx, words in enumerate(values):
            if isinstance(words, basestring):
                words = [words]
            if words is not None:
                self.add([index.Term(name, word) for word in words], idx)
        return self
    @classmethod
    def wildcard(cls, name, value):
        "Return lucene WildcardQuery."
        return cls(search.WildcardQuery, index.Term(name, value))
    @classmethod
    def fuzzy(cls, name, value, minimumSimilarity=0.5, prefixLength=0):
        "Return lucene FuzzyQuery."
        return cls(search.FuzzyQuery, index.Term(name, value), minimumSimilarity, prefixLength)
    def __pos__(self):
        return Query.all(self)
    def __neg__(self):
        return Query.boolean(search.BooleanClause.Occur.MUST_NOT, self)
    def __and__(self, other):
        return Query.all(self, other)
    def __rand__(self, other):
        return Query.all(other, self)
    def __or__(self, other):
        return Query.any(self, other)
    def __ror__(self, other):
        return Query.any(other, self)
    def __sub__(self, other):
        return Query.any(self).__isub__(other)
    def __rsub__(self, other):
        return Query.any(other).__isub__(self)
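
# Usage sketch: the overloaded operators compose BooleanQueries. The field
# name 'text' and the values are illustrative only, and lucene.initVM() must
# have been called before constructing queries.
#
#   q = Query.term('text', 'hello') & Query.prefix('text', 'wor')
#   q |= Query.fuzzy('text', 'wrold')    # add an optional (SHOULD) clause
#   q -= Query.term('text', 'spam')      # add a prohibited (MUST_NOT) clause
#   assert isinstance(q, search.BooleanQuery) and len(q) == 4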

class BooleanQuery(Query):
    "Inherited lucene BooleanQuery with sequence interface to clauses."
    def __len__(self):
        return len(self.getClauses())
    def __iter__(self):
        return iter(self.getClauses())
    def __getitem__(self, index):
        return self.getClauses()[index]
    def __iand__(self, other):
        self.add(other, search.BooleanClause.Occur.MUST)
        return self
    def __ior__(self, other):
        self.add(other, search.BooleanClause.Occur.SHOULD)
        return self
    def __isub__(self, other):
        self.add(other, search.BooleanClause.Occur.MUST_NOT)
        return self
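
# Sketch of the sequence interface over clauses (illustrative field and values):
#
#   q = Query.any(text=['hello', 'world'])    # two SHOULD term clauses
#   len(q)                                    # 2
#   [clause.query for clause in q]            # the wrapped TermQueries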

class SpanQuery(Query):
    "Inherited lucene SpanQuery with additional span constructors."
    def filter(self, cache=True):
        "Return lucene CachingSpanFilter, optionally just SpanQueryFilter."
        filter = search.SpanQueryFilter(self)
        return search.CachingSpanFilter(filter) if cache else filter
    def __getitem__(self, slc):
        start, stop, step = slc.indices(Integer.MAX_VALUE)
        assert step == 1, 'slice step is not supported'
        return SpanQuery(spans.SpanPositionRangeQuery, self, start, stop)
    def __sub__(self, other):
        return SpanQuery(spans.SpanNotQuery, self, other)
    def __or__(*spans_):  # unbound method: self is the first of spans_
        return SpanQuery(spans.SpanOrQuery, spans_)
    def near(*spans_, **kwargs):
        """Return lucene SpanNearQuery from SpanQueries.
        
        :param slop: default 0
        :param inOrder: default True
        :param collectPayloads: default True
        """
        args = map(kwargs.get, ('slop', 'inOrder', 'collectPayloads'), (0, True, True))
        return SpanQuery(spans.SpanNearQuery, spans_, *args)
    def mask(self, name):
        "Return lucene FieldMaskingSpanQuery, which allows combining span queries from different fields."
        return SpanQuery(spans.FieldMaskingSpanQuery, self, name)
    def payload(self, *values):
        "Return lucene SpanPayloadCheckQuery from payload values."
        base = spans.SpanNearPayloadCheckQuery if spans.SpanNearQuery.instance_(self) else spans.SpanPayloadCheckQuery
        return SpanQuery(base, self, Arrays.asList(list(map(lucene.JArray_byte, values))))
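
# Sketch of span composition (illustrative field and values):
#
#   near = Query.near('text', 'hello', 'world', slop=1)
#   first = Query.span('text', 'hello')[:10]     # only within the first 10 positions
#   q = near - Query.span('text', 'goodbye')     # SpanNotQuery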

class TermsFilter(search.CachingWrapperFilter):
    """Caching filter based on a unique field and set of matching values.
    Optimized for many terms and docs, with support for incremental updates.
    Suitable for searching external metadata associated with indexed identifiers.
    Call :meth:`refresh` to cache a new (or reopened) reader.
    
    :param field: field name
    :param values: initial term values, synchronized with the cached filters
    """
    ops = {'or': 'update', 'and': 'intersection_update', 'andNot': 'difference_update'}
    def __init__(self, field, values=()):
        assert lucene.VERSION >= '3.5', 'requires FixedBitSet set operations introduced in lucene 3.5'
        search.CachingWrapperFilter.__init__(self, search.TermsFilter())
        self.field = field
        self.values = set(values)
        self.readers = set()
        self.lock = threading.Lock()
    def filter(self, values, cache=True):
        "Return lucene TermsFilter, optionally using the FieldCache."
        if cache:
            return search.FieldCacheTermsFilter(self.field, tuple(values))
        filter, term = search.TermsFilter(), index.Term(self.field)
        for value in values:
            filter.addTerm(term.createTerm(value))
        return filter
    def apply(self, filter, op, readers):
        for reader in readers:
            bitset = util.FixedBitSet.cast_(self.getDocIdSet(reader))
            getattr(bitset, op)(filter.getDocIdSet(reader).iterator())
    def update(self, values, op='or', cache=True):
        """Update allowed values and corresponding cached bitsets.
        
        :param values: additional term values
        :param op: set operation used to combine terms and docs: *and*, *or*, *andNot*
        :param cache: optionally cache all term values using FieldCache
        """
        values = tuple(values)
        filter = self.filter(values, cache)
        with self.lock:
            getattr(self.values, self.ops[op])(values)
            self.apply(filter, op, self.readers)
    def refresh(self, reader):
        "Refresh cached bitsets of current values for new segments of top-level reader."
        readers = set(reader.sequentialSubReaders)
        with self.lock:
            self.apply(self.filter(self.values), 'or', readers - self.readers)
            self.readers = set(reader for reader in readers | self.readers if reader.refCount)
    def add(self, *values):
        "Add a few term values."
        self.update(values, cache=False)
    def discard(self, *values):
        "Discard a few term values."
        self.update(values, op='andNot', cache=False)
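
# Usage sketch, assuming a searcher with a top-level reader and an indexed
# unique field; 'uid' and the values are illustrative.
#
#   terms_filter = TermsFilter('uid', ['0', '1'])
#   terms_filter.refresh(searcher.indexReader)   # cache bitsets for current segments
#   terms_filter.add('2')                        # incrementally allow another value
#   hits = searcher.search(query, filter=terms_filter)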

class Comparator(object):
    "Chained arrays with bisection lookup."
    def __init__(self, arrays):
        self.arrays = list(arrays)
        self.offsets = [0]
        for array in self.arrays:
            self.offsets.append(len(self) + len(array))
    def __len__(self):
        return self.offsets[-1]
    def __iter__(self):
        return itertools.chain(*self.arrays)
    def __getitem__(self, index):
        point = bisect.bisect_right(self.offsets, index) - 1
        return self.arrays[point][index - self.offsets[point]]
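
# Example: offsets are a running total of array lengths, so lookup bisects to
# the owning array.
#
#   comparator = Comparator([[0, 1, 2], [10, 11]])
#   len(comparator)     # 5
#   comparator[3]       # 10, i.e., arrays[1][0]
#   list(comparator)    # [0, 1, 2, 10, 11]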

class SortField(search.SortField):
    """Inherited lucene SortField used for caching FieldCache parsers.
    
    :param name: field name
    :param type: type object or name compatible with SortField constants
    :param parser: lucene FieldCache.Parser or callable applied to field values
    :param reverse: reverse flag used with sort
    """
    def __init__(self, name, type='string', parser=None, reverse=False):
        type = self.typename = getattr(type, '__name__', type).capitalize()
        if parser is None:
            parser = getattr(search.SortField, type.upper())
        elif not search.FieldCache.Parser.instance_(parser):
            base = getattr(search_, 'Python{0}Parser'.format(type))
            namespace = {'parse' + type: staticmethod(parser)}
            # object.__class__ is type: dynamically subclass the parser extension, delegating to the callable
            parser = object.__class__(base.__name__, (base,), namespace)()
        search.SortField.__init__(self, name, parser, reverse)
    def array(self, reader):
        method = getattr(search.FieldCache.DEFAULT, 'get{0}s'.format(self.typename))
        return method(reader, self.field, *[self.parser][:bool(self.parser)])  # parser argument is only passed if set
    def comparator(self, reader):
        "Return indexed values from default FieldCache using the given top-level reader."
        readers = reader.sequentialSubReaders
        if index.MultiReader.instance_(reader):
            readers = itertools.chain.from_iterable(reader.sequentialSubReaders for reader in readers)
        arrays = list(map(self.array, readers))
        return arrays[0] if len(arrays) <= 1 else Comparator(arrays)
    def filter(self, start, stop, lower=True, upper=False):
        "Return lucene FieldCacheRangeFilter based on field and type."
        method = getattr(search.FieldCacheRangeFilter, 'new{0}Range'.format(self.typename))
        return method(self.field, self.parser, start, stop, lower, upper)
    def terms(self, filter, *readers):
        "Generate field cache terms from docs which match filter from all segments."
        for reader in readers:
            array, docset = self.array(reader), filter.getDocIdSet(reader)
            for id in iter(docset.iterator().nextDoc, search.DocIdSetIterator.NO_MORE_DOCS):
                yield array[id]
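
# Usage sketch (illustrative field names); a custom parser may be any callable,
# which is wrapped in the matching Python*Parser extension.
#
#   sort = search.Sort(SortField('size', int, reverse=True))
#   counts = SortField('count', long, parser=lambda value: long(value or 0))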

class Highlighter(highlight.Highlighter):
    """Inherited lucene Highlighter with stored analysis options.
    
    :param searcher: `IndexSearcher`_ used for analysis, scoring, and optionally text retrieval
    :param query: lucene Query
    :param field: field name of text
    :param terms: highlight any matching term in query regardless of position
    :param fields: highlight matching terms from any field
    :param tag: optional html tag name
    :param formatter: optional lucene Formatter
    :param encoder: optional lucene Encoder
    """
    def __init__(self, searcher, query, field, terms=False, fields=False, tag='', formatter=None, encoder=None):
        if tag:
            formatter = highlight.SimpleHTMLFormatter('<{0}>'.format(tag), '</{0}>'.format(tag))
        # the scorer is bound to the reader and field unless highlighting matches from any field
        scorer = (highlight.QueryTermScorer if terms else highlight.QueryScorer)(query, *(searcher.indexReader, field) * (not fields))
        highlight.Highlighter.__init__(self, *filter(None, [formatter, encoder, scorer]))
        self.searcher, self.field = searcher, field
        self.selector = document.MapFieldSelector([field])
    def fragments(self, doc, count=1):
        """Return highlighted text fragments.
        
        :param doc: text string or doc id to be highlighted
        :param count: maximum number of fragments
        """
        if not isinstance(doc, basestring):
            doc = self.searcher.doc(doc, self.selector)[self.field]
        return doc and list(self.getBestFragments(self.searcher.analyzer, self.field, doc, count))
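
# Usage sketch: searcher, query, and the hit id come from the surrounding
# application; 'text' is an illustrative stored field.
#
#   highlighter = Highlighter(searcher, query, 'text', tag='b')
#   highlighter.fragments(id, count=2)   # up to 2 fragments with <b> markup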

class FastVectorHighlighter(vectorhighlight.FastVectorHighlighter):
    """Inherited lucene FastVectorHighlighter with stored query.
    Fields must be stored and have term vectors with offsets and positions.
    
    :param searcher: `IndexSearcher`_ with stored term vectors
    :param query: lucene Query
    :param field: field name of text
    :param terms: highlight any matching term in query regardless of position
    :param fields: highlight matching terms from any field
    :param tag: optional html tag name
    :param fragListBuilder: optional lucene FragListBuilder
    :param fragmentsBuilder: optional lucene FragmentsBuilder
    """
    def __init__(self, searcher, query, field, terms=False, fields=False, tag='', fragListBuilder=None, fragmentsBuilder=None):
        if tag:
            fragmentsBuilder = vectorhighlight.SimpleFragmentsBuilder(['<{0}>'.format(tag)], ['</{0}>'.format(tag)])
        args = fragListBuilder or vectorhighlight.SimpleFragListBuilder(), fragmentsBuilder or vectorhighlight.SimpleFragmentsBuilder()
        vectorhighlight.FastVectorHighlighter.__init__(self, not terms, not fields, *args)
        self.searcher, self.field = searcher, field
        self.query = self.getFieldQuery(query)
    def fragments(self, id, count=1, size=100):
        """Return highlighted text fragments.
        
        :param id: document id
        :param count: maximum number of fragments
        :param size: maximum number of characters in fragment
        """
        return list(self.getBestFragments(self.query, self.searcher.indexReader, id, self.field, size, count))
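
# Usage sketch: requires 'text' to be stored with term vectors, offsets, and
# positions; names are illustrative.
#
#   highlighter = FastVectorHighlighter(searcher, query, 'text', tag='em')
#   highlighter.fragments(id, count=2, size=80)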

class SpellChecker(dict):
    """Correct spellings and suggest words for queries.
    Supply a vocabulary mapping words to (reverse) sort keys, such as document frequencies.
    """
    def __init__(self, *args, **kwargs):
        dict.__init__(self, *args, **kwargs)
        self.words = sorted(self)
        self.alphabet = sorted(set(itertools.chain.from_iterable(self.words)))
        self.suffix = self.alphabet[-1] * max(map(len, self.words)) if self.alphabet else ''
        self.prefixes = set(word[:stop] for word in self.words for stop in range(len(word) + 1))
    def suggest(self, prefix, count=None):
        "Return ordered suggested words for prefix."
        start = bisect.bisect_left(self.words, prefix)
        stop = bisect.bisect_right(self.words, prefix + self.suffix, start)
        words = self.words[start:stop]
        if count is not None and count < len(words):
            return heapq.nlargest(count, words, key=self.__getitem__)
        words.sort(key=self.__getitem__, reverse=True)
        return words
    def edits(self, word, length=0):
        "Return set of potential words one edit distance away, mapped to valid prefix lengths."
        pairs = [(word[:index], word[index:]) for index in range(len(word) + 1)]
        deletes = (head + tail[1:] for head, tail in pairs[:-1])
        transposes = (head + tail[1::-1] + tail[2:] for head, tail in pairs[:-2])
        edits = {} if length else dict.fromkeys(itertools.chain(deletes, transposes), 0)
        for head, tail in pairs[length:]:
            if head not in self.prefixes:
                break
            for char in self.alphabet:
                prefix = head + char
                if prefix in self.prefixes:
                    edits[prefix + tail] = edits[prefix + tail[1:]] = len(prefix)
        return edits
    def correct(self, word):
        "Generate ordered sets of words by increasing edit distance."
        previous, edits = set(), {word: 0}
        for distance in range(len(word)):
            yield sorted(filter(self.__contains__, edits), key=self.__getitem__, reverse=True)
            previous.update(edits)
            groups = map(self.edits, edits, edits.values())
            edits = dict((edit, group[edit]) for group in groups for edit in group if edit not in previous)
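
# Example with a tiny vocabulary mapped to frequencies; suggestions and
# corrections are ordered by descending value.
#
#   checker = SpellChecker({'hello': 10, 'helm': 7, 'help': 5})
#   checker.suggest('hel')            # ['hello', 'helm', 'help']
#   checker.suggest('hel', count=1)   # ['hello']
#   corrections = checker.correct('helo')
#   next(corrections)                 # []: no exact match at distance 0
#   next(corrections)                 # ['hello', 'helm', 'help'] at distance 1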

class SpellParser(PythonQueryParser):
    """Inherited lucene QueryParser which corrects spelling.
    Assign a searcher attribute or override :meth:`correct` implementation.
    """
    def correct(self, term):
        "Return term with text replaced as necessary."
        field = term.field()
        for text in self.searcher.correct(field, term.text()):
            return index.Term(field, text)
        return term
    def rewrite(self, query):
        "Return term or phrase query with corrected terms substituted."
        if search.TermQuery.instance_(query):
            term = search.TermQuery.cast_(query).term
            return search.TermQuery(self.correct(term))
        query = search.PhraseQuery.cast_(query)
        phrase = search.PhraseQuery()
        for position, term in zip(query.positions, query.terms):
            phrase.add(self.correct(term), position)
        return phrase
    # PythonQueryParser extension hooks: substitute corrected terms into quoted and sloppy field queries
    def getFieldQuery_quoted(self, *args):
        return self.rewrite(self.getFieldQuery_quoted_super(*args))
    def getFieldQuery_slop(self, *args):
        return self.rewrite(self.getFieldQuery_slop_super(*args))
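
# Usage sketch: the searcher attribute needs a correct(field, text) method
# generating candidate spellings; the constructor takes the usual lucene
# QueryParser arguments (illustrative here).
#
#   parser = SpellParser(util.Version.LUCENE_CURRENT, 'text', analyzer)
#   parser.searcher = searcher
#   parser.parse('"helo wrold"')   # phrase comes back with corrected terms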