Commits

Matt Chaput committed e16c7bc

Finished exclusive and open brackets on numeric ranges. Added ConstantScoreQuery.
Numeric ranges now wrap the compiled query in ConstantScoreQuery by default.
Renamed minbound/maxbound to start/end in numeric.
Renamed do() thow-away funcs in tests to check().

Comments (0)

Files changed (6)

src/whoosh/matching.py

     
     def score(self):
         return self.child.score() * self.boost
-    
+
 
 class MultiMatcher(Matcher):
     """Serializes the results of a list of sub-matchers.
         
     def value_as(self, astype):
         return self.a.value_as(astype)
+
+
+class ConstantScoreMatcher(WrappingMatcher):
+    def __init__(self, child, score=1.0):
+        super(ConstantScoreMatcher, self).__init__(child)
+        self._score = score
     
+    def quality(self):
+        return self._score
+    
+    def block_quality(self):
+        return self._score
+    
+    def score(self):
+        return self._score
+    
+
+
 
 #class PhraseMatcher(WrappingMatcher):
 #    """Matches postings where a list of sub-matchers occur next to each other

src/whoosh/query.py

 from whoosh.matching import (AndMaybeMatcher, DisjunctionMaxMatcher,
                              ListMatcher, IntersectionMatcher, InverseMatcher,
                              NullMatcher, RequireMatcher, UnionMatcher,
-                             WrappingMatcher)
+                             WrappingMatcher, ConstantScoreMatcher)
 from whoosh.reading import TermNotFound
 from whoosh.support.bitvector import BitVector
 from whoosh.support.levenshtein import relative
         return visitor(copy.deepcopy(self))
 
 
+class WrappingQuery(Query):
+    def __init__(self, child):
+        self.child = child
+        
+    def copy(self):
+        return self.__class__(self.child)
+    
+    def all_terms(self, termset=None, phrases=True):
+        return self.child.all_terms(termset=termset, phrases=phrases)
+    
+    def existing_terms(self, ixreader, termset=None, reverse=False,
+                       phrases=True):
+        return self.child.existing_terms(ixreader, termset=termset,
+                                         reverse=reverse, phrases=phrases)
+    
+    def estimate_size(self, ixreader):
+        return self.child.estimate_size(ixreader)
+    
+    def matcher(self, searcher, exclude_docs=None):
+        return self.child.matcher(searcher, exclude_docs=exclude_docs)
+    
+    def replace(self, oldtext, newtext):
+        return self.__class__(self.child.replace(oldtext, newtext))
+    
+    def accept(self, visitor):
+        return self.__class__(self.child.accept(visitor))
+
+
 class CompoundQuery(Query):
     """Abstract base class for queries that combine or manipulate the results
     of multiple sub-queries .
 
     def normalize(self):
         if self.start == self.end:
+            if self.startexcl or self.endexcl:
+                return NullQuery
             return Term(self.fieldname, self.start, boost=self.boost)
         else:
             return TermRange(self.fieldname, self.start, self.end,
     """
     
     def __init__(self, fieldname, start, end, startexcl=False, endexcl=False,
-                 boost=1.0):
+                 boost=1.0, constantscore=True):
         """
         :param fieldname: The name of the field to search.
         :param start: Match terms equal to or greater than this number. This
             range end is inclusive.
         :param boost: Boost factor that should be applied to the raw score of
             results matched by this query.
+        :param constantscore: If True, the compiled query returns a constant
+            score (the value of the ``boost`` keyword argument) instead of
+            actually scoring the matched terms. This gives a nice speed boost
+            and won't affect the results in most cases since numeric ranges
+            will almost always be used as a filter.
         """
 
         self.fieldname = fieldname
         self.startexcl = startexcl
         self.endexcl = endexcl
         self.boost = boost
+        self.constantscore = constantscore
     
     def __repr__(self):
         return '%s(%r, %r, %r, %s, %s)' % (self.__class__.__name__,
         if self.endexcl: endchar = "}"
         return u"%s:%s%s TO %s%s" % (self.fieldname,
                                      startchar, self.start, self.end, endchar)
-        
+    
     def copy(self):
         return NumericRange(self.fieldname, self.start, self.end,
                             self.startexcl, self.endexcl, boost=self.boost)
         subqueries = []
         # Get the term ranges for the different resolutions
         for starttext, endtext in tiered_ranges(field.type, self.start,
-                                                self.end, field.shift_step):
+                                                self.end, field.shift_step,
+                                                self.startexcl, self.endexcl):
             if starttext == endtext:
                 subq = Term(self.fieldname, starttext)
             else:
             subqueries.append(subq)
         
         if len(subqueries) == 1:
-            return subqueries[0] 
+            q = subqueries[0] 
         elif subqueries:
-            return Or(subqueries, boost=self.boost)
+            q = Or(subqueries, boost=self.boost)
         else:
             return NullQuery
         
+        if self.constantscore:
+            q = ConstantScoreQuery(q, self.boost)
+        return q
+        
     def matcher(self, searcher, exclude_docs=None):
         q = self._compile_query(searcher.reader())
         return q.matcher(searcher, exclude_docs=exclude_docs)
 NullQuery = NullQuery()
 
 
+class ConstantScoreQuery(WrappingQuery):
+    def __init__(self, child, score=1.0):
+        super(ConstantScoreQuery, self).__init__(child)
+        self.score = score
+    
+    def copy(self):
+        return ConstantScoreQuery(self.child, self.score)
+    
+    def matcher(self, searcher, exclude_docs=None):
+        m = self.child.matcher(searcher, exclude_docs=None)
+        if isinstance(m, NullMatcher):
+            return m
+        else:
+            return ConstantScoreMatcher(m, self.score)
+        
+    def replace(self, oldtext, newtext):
+        return self.__class__(self.child.replace(oldtext, newtext), self.score)
+    
+    def accept(self, visitor):
+        return self.__class__(self.child.accept(visitor), self.score)
+
+
 class Require(CompoundQuery):
     """Binary query returns results from the first query that also appear in
     the second query, but only uses the scores from the first query. This lets

src/whoosh/support/numeric.py

 from array import array
 
 
-def split_range(valsize, step, minbound, maxbound):
-    """Splits a range of numbers (from ``minbound`` to ``maxbound``, inclusive)
-    into a sequence of trie ranges of the form ``(start, end, shift)``.
-    The consumer of these tuples is expected to shift the ``start`` and ``end``
+def split_range(valsize, step, start, end):
+    """Splits a range of numbers (from ``start`` to ``end``, inclusive)
+    into a sequence of trie ranges of the form ``(start, end, shift)``. The
+    consumer of these tuples is expected to shift the ``start`` and ``end``
     right by ``shift``.
     
     This is used for generating term ranges for a numeric field. The queries
     while True:
         diff = 1 << (shift + step)
         mask = ((1 << step) - 1) << shift
+        setbits = lambda x: x | ((1 << shift) - 1)
         
-        haslower = (minbound & mask) != 0
-        hasupper = (maxbound & mask) != mask
+        haslower = (start & mask) != 0
+        hasupper = (end & mask) != mask
         
-        not_mask = ~mask & ((1 << valsize+1) - 1)
-        nextmin = (minbound + diff if haslower else minbound) & not_mask
-        nextmax = (maxbound - diff if hasupper else maxbound) & not_mask
+        not_mask = ~mask & ((1 << valsize + 1) - 1)
+        nextstart = (start + diff if haslower else start) & not_mask
+        nextend = (end - diff if hasupper else end) & not_mask
         
-        if shift + step >= valsize or nextmin > nextmax:
-            yield (minbound, maxbound | ((1 << shift) - 1), shift)
+        if shift + step >= valsize or nextstart > nextend:
+            yield (start, setbits(end), shift)
             break
         
         if haslower:
-            yield (minbound, (minbound | mask) | ((1 << shift) - 1), shift)
+            yield (start, setbits(start | mask), shift)
         if hasupper:
-            yield (maxbound & not_mask, maxbound | ((1 << shift) - 1), shift)
+            yield (end & not_mask, setbits(end), shift)
         
-        minbound = nextmin
-        maxbound = nextmax
+        start = nextstart
+        end = nextend
         shift += step
 
 
 
 # Functions for generating tiered ranges
 
-def tiered_ranges(numtype, start, end, shift_step):
+_max_sortable_int = 4294967295L
+_max_sortable_long = 18446744073709551615L
+
+def tiered_ranges(numtype, start, end, shift_step, startexcl, endexcl):
     # First, convert the start and end of the range to sortable representations
+    
+    valsize = 32 if numtype is int else 64
+    
+    # Convert start and end values to sortable ints
+    if start is None:
+        start = 0
+    else:
+        if numtype is int:
+            start = int_to_sortable_int(start)
+        elif numtype is long:
+            start = long_to_sortable_long(start)
+        elif numtype is float:
+            start = float_to_sortable_long(start)
+        if startexcl: start += 1
+    
+    if end is None:
+        end = _max_sortable_int if valsize == 32 else _max_sortable_long
+    else:
+        if numtype is int:
+            end = int_to_sortable_int(end)
+        elif numtype is long:
+            end = long_to_sortable_long(end)
+        elif numtype is float:
+            end = float_to_sortable_long(end)
+        if endexcl: end -= 1
+    
     if numtype is int:
-        valsize = 32
-        start = int_to_sortable_int(start)
-        end = int_to_sortable_int(end)
         to_text = sortable_int_to_text
     else:
-        valsize = 64
-        if numtype is long:
-            start = long_to_sortable_long(start)
-            end = long_to_sortable_long(end)
-        elif numtype is float:
-            # Convert floats to longs
-            start = float_to_sortable_long(start)
-            end = float_to_sortable_long(end)
         to_text = sortable_long_to_text
     
     if not shift_step:
     for rstart, rend, shift in split_range(valsize, shift_step, start, end):
         starttext = to_text(rstart, shift=shift)
         endtext = to_text(rend, shift=shift)
-        
         yield (starttext, endtext)
 
 

tests/test_analysis.py

         iwf = IntraWordFilter(mergewords=True, mergenums=True)
         ana = RegexTokenizer(r"\S+") | iwf
         
-        def do(text, ls):
+        def check(text, ls):
             self.assertEqual([(t.pos, t.text) for t in ana(text)], ls)
             
-        do(u"PowerShot", [(0, "Power"), (1, "Shot"), (1, "PowerShot")])
-        do(u"A's+B's&C's", [(0, "A"), (1, "B"), (2, "C"), (2, "ABC")])
-        do(u"Super-Duper-XL500-42-AutoCoder!", [(0, "Super"), (1, "Duper"), (2, "XL"),
-                                                (2, "SuperDuperXL"), (3, "500"), (4, "42"),
-                                                (4, "50042"), (5, "Auto"), (6, "Coder"),
-                                                (6, "AutoCoder")])
+        check(u"PowerShot", [(0, "Power"), (1, "Shot"), (1, "PowerShot")])
+        check(u"A's+B's&C's", [(0, "A"), (1, "B"), (2, "C"), (2, "ABC")])
+        check(u"Super-Duper-XL500-42-AutoCoder!", [(0, "Super"), (1, "Duper"), (2, "XL"),
+                                                   (2, "SuperDuperXL"), (3, "500"), (4, "42"),
+                                                   (4, "50042"), (5, "Auto"), (6, "Coder"),
+                                                   (6, "AutoCoder")])
     
     def test_biword(self):
         ana = RegexTokenizer(r"\w+") | BiWordFilter()

tests/test_fields.py

         floatf = fields.NUMERIC(float, shift_step=0)
         
         def roundtrip(obj, num):
-            self.assertAlmostEqual(obj.from_text(obj.to_text(num)), num, 4)
+            self.assertEqual(obj.from_text(obj.to_text(num)), num)
         
         roundtrip(intf, 0)
         roundtrip(intf, 12345)
         roundtrip(floatf, -582.592)
         roundtrip(floatf, -99.42)
         
-    def test_numeric_sort(self):
-        intf = fields.NUMERIC(int, shift_step=0)
-        longf = fields.NUMERIC(long, shift_step=0)
-        floatf = fields.NUMERIC(float, shift_step=0)
-        
         from random import shuffle
         def roundtrip_sort(obj, start, end, step):
             count = start
             shuffle(scrabled)
             round = [obj.from_text(t) for t
                      in sorted([obj.to_text(n) for n in scrabled])]
-            for n1, n2 in zip(round, rng):
-                self.assertAlmostEqual(n1, n2, 2, "n1=%r n2=%r type=%s" % (n1, n2, obj.type))
+            self.assertEqual(round, rng)
         
         roundtrip_sort(intf, -100, 100, 1)
         roundtrip_sort(longf, -58902, 58249, 43)
         r = s.search(qp.parse("0.536255"))
         self.assertEqual(r[0]["id"], "b")
     
-    def test_numeric_range(self):
-        def test_type(t, start, end, step, teststart, testend):
-            fld = fields.NUMERIC(t)
-            schema = fields.Schema(id=fields.STORED, number=fld)
-            ix = RamStorage().create_index(schema)
-            
-            w = ix.writer()
-            n = start
-            while n <= end:
-                w.add_document(id=n, number=n)
-                n += step
-            w.commit()
-            
-            qp = qparser.QueryParser("number", schema=schema)
-            q = qp.parse("[%s to %s]" % (teststart, testend))
-            self.assertEqual(q.__class__, query.NumericRange)
-            self.assertEqual(q.start, teststart)
-            self.assertEqual(q.end, testend)
-            
-            s = ix.searcher()
-            self.assertEqual(q._compile_query(s.reader()).__class__, query.Or)
-            rng = []
-            count = teststart
-            while count <= testend:
-                rng.append(count)
-                count += step
-            
-            found = [s.stored_fields(d)["id"] for d in q.docs(s)]
-            self.assertEqual(found, rng)
+    def test_numeric_parsing(self):
+        schema = fields.Schema(id=fields.ID(stored=True), number=fields.NUMERIC)
         
-        test_type(float, -50.0, 50.0, 0.5, -45.5, 39.0)
-        test_type(int, -5, 500, 1, 10, 400)
-        test_type(int, -500, 500, 5, -350, 280)
-        test_type(long, -1000, 1000, 5, -900, 90)
-    
-    def test_open_numeric_ranges(self):
-        schema = fields.Schema(id=fields.ID(stored=True),
-                               view_count=fields.NUMERIC(stored=True))
+        qp = qparser.QueryParser("number", schema=schema)
+        q = qp.parse("[10 to *]")
+        self.assertEqual(q, query.NullQuery)
+        
+        q = qp.parse("[to 400]")
+        self.assertEqual(q.__class__, query.NumericRange)
+        self.assertEqual(q.start, None)
+        self.assertEqual(q.end, 400)
+        
+        q = qp.parse("[10 to]")
+        self.assertEqual(q.__class__, query.NumericRange)
+        self.assertEqual(q.start, 10)
+        self.assertEqual(q.end, None)
+        
+        q = qp.parse("[10 to 400]")
+        self.assertEqual(q.__class__, query.NumericRange)
+        self.assertEqual(q.start, 10)
+        self.assertEqual(q.end, 400)
+        
+    def test_numeric_ranges(self):
+        schema = fields.Schema(id=fields.STORED, num=fields.NUMERIC)
         ix = RamStorage().create_index(schema)
+        w = ix.writer()
         
-        w = ix.writer()
-        for i, letter in enumerate(u"abcdefghijklmno"):
-            w.add_document(id=letter, view_count=(i + 1) * 101)
+        for i in xrange(400):
+            w.add_document(id=i, num=i)
         w.commit()
         
         s = ix.searcher()
-        #from whoosh.qparser.old import QueryParser
-        #qp = QueryParser("id", schema=schema)
-        qp = qparser.QueryParser("id", schema=schema)
+        qp = qparser.QueryParser("num", schema=schema)
         
-        def do(qstring, target):
-            q = qp.parse(qstring)
-            results = "".join(sorted([d['id'] for d in s.search(q, limit=None)]))
-            self.assertEqual(results, target, "%r: %s != %s" % (q, results, target))
+        def check(qs, target):
+            q = qp.parse(qs)
+            result = [s.stored_fields(d)["id"] for d in q.docs(s)]
+            self.assertEqual(result, target)
         
-        do(u"view_count:[0 TO]", "abcdefghijklmno")
-        do(u"view_count:[1000 TO]", "jklmno")
-        do(u"view_count:[TO 300]", "ab")
-        do(u"view_count:[200 TO 500]", "bcd")
-        do(u"view_count:{202 TO]", "cdefghijklmno")
-        do(u"view_count:[TO 505}", "abcd")
-        do(u"view_count:{202 TO 404}", "c")
+        # Note that range() is always inclusive-exclusive
+        check("[10 to 390]", range(10, 390+1))
+        check("[100 to]", range(100, 400))
+        check("[to 350]", range(0, 350+1))
+        check("[16 to 255]", range(16, 255+1))
+        check("{10 to 390]", range(11, 390+1))
+        check("[10 to 390}", range(10, 390))
+        check("{10 to 390}", range(11, 390))
+        check("{16 to 255}", range(17, 255))
     
-    def test_numeric_steps(self):
-        for step in range(0, 32):
-            schema = fields.Schema(id = fields.STORED,
-                                   num=fields.NUMERIC(int, shift_step=step))
-            ix = RamStorage().create_index(schema)
-            w = ix.writer()
-            for i in xrange(-10, 10):
-                w.add_document(id=i, num=i)
-            w.commit()
-            
-            s = ix.searcher()
-            q = query.NumericRange("num", -9, 9)
-            r = [s.stored_fields(d)["id"] for d in q.docs(s)]
-            self.assertEqual(r, range(-9, 10))
-            
     def test_datetime(self):
         schema = fields.Schema(id=fields.ID(stored=True),
                                date=fields.DATETIME(stored=True))

tests/test_searching.py

         w.commit()
         s = ix.searcher()
         
-        def do(startexcl, endexcl, string):
+        def check(startexcl, endexcl, string):
             q = TermRange("id", "b", "f", startexcl, endexcl)
             r = "".join(sorted(d['id'] for d in s.search(q)))
             self.assertEqual(r, string)
             
-        do(False, False, "bcdef")
-        do(True, False, "cdef")
-        do(True, True, "cde")
-        do(False, True, "bcde")
+        check(False, False, "bcdef")
+        check(True, False, "cdef")
+        check(True, True, "cde")
+        check(False, True, "bcde")
         
     def test_open_ranges(self):
         schema = fields.Schema(id=fields.ID(stored=True))
         #from whoosh.qparser.old import QueryParser
         #qp = QueryParser("id", schema=schema)
         qp = qparser.QueryParser("id", schema=schema)
-        def do(qstring, result):
+        def check(qstring, result):
             q = qp.parse(qstring)
             r = "".join(sorted([d['id'] for d in s.search(q)]))
             self.assertEqual(r, result)
             
-        do(u"[b TO]", "bcdefg")
-        do(u"[TO e]", "abcde")
-        do(u"[b TO d]", "bcd")
-        do(u"{b TO]", "cdefg")
-        do(u"[TO e}", "abcd")
-        do(u"{b TO d}", "c")
+        check(u"[b TO]", "bcdefg")
+        check(u"[TO e]", "abcde")
+        check(u"[b TO d]", "bcd")
+        check(u"{b TO]", "cdefg")
+        check(u"[TO e}", "abcd")
+        check(u"{b TO d}", "c")
     
     def test_keyword_or(self):
         schema = fields.Schema(a=fields.ID(stored=True), b=fields.KEYWORD)