Matt Chaput committed 03475e7

Save the score/sort keys in the group lists and then sort the group lists by the scores/keys
so docs in the group lists are in the same order as in the full results.

See issue #197.
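In outline: each grouping list now stores a (sortkey, docnum) pair per document instead of a bare document number, and Results.groups() sorts those pairs and strips the keys on the way out. A minimal standalone sketch of the idea (add_to_group and get_group are illustrative names, not Whoosh's API):

    from collections import defaultdict

    # Each group holds (sortkey, docnum) pairs instead of bare docnums.
    groups = defaultdict(list)

    def add_to_group(key, docnum, sortkey):
        # The caller negates scores so that an ascending sort puts
        # higher-scoring documents first.
        groups[key].append((sortkey, docnum))

    def get_group(key):
        # Sort by the stored keys, then discard them, so the docnums
        # come back in the same order as the full results.
        return [docnum for _, docnum in sorted(groups[key])]

    add_to_group("apple", 0, 0 - 1.0)
    add_to_group("apple", 2, 0 - 3.0)
    add_to_group("apple", 4, 0 - 5.0)
    assert get_group("apple") == [4, 2, 0]  # highest score first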


Files changed (2)

src/whoosh/searching.py

         self.timer = None
         self.timedout = True
 
-    def _add_to_group(self, name, key, offsetid):
+    def _add_to_group(self, name, key, offsetid, sortkey):
         if self.groupids:
-            self.groups[name][key].append(offsetid)
+            self.groups[name][key].append((sortkey, offsetid))
         else:
             self.groups[name][key] += 1
 
-    def collect(self, id, offsetid):
+    def collect(self, id, offsetid, sortkey):
         docset = self.docset
         if docset is not None:
             docset.add(offsetid)
             for name, catter in self.categorizers.items():
                 if catter.allow_overlap:
                     for key in catter.keys_for_id(id):
-                        add(name, catter.key_to_name(key), offsetid)
+                        add(name, catter.key_to_name(key), offsetid, sortkey)
                 else:
                     key = catter.key_to_name(catter.key_for_id(id))
-                    add(name, key, offsetid)
+                    add(name, key, offsetid, sortkey)
 
     def search(self, searcher, q, allow=None, restrict=None):
         """Top-level method call which uses the given :class:`Searcher` and
             if ((not allow or offsetid in allow)
                 and (not restrict or offsetid not in restrict)):
                 # Collect and yield this document
-                collect(id, offsetid)
                 if scorefn:
                     score = scorefn(matcher)
+                    collect(id, offsetid, score)
                 else:
                     score = matcher.score()
+                    collect(id, offsetid, 0 - score)
                 yield (score, offsetid)
 
             # If recording terms, add the document to the termlists
             if ((not allow or offsetid in allow)
                 and (not restrict or offsetid not in restrict)):
                 # Collect and yield this document
-                collect(id, offsetid)
-                yield (keyfn(id), offsetid)
+                key = keyfn(id)
+                collect(id, offsetid, key)
+                yield (key, offsetid)
 
             # Check whether the time limit expired
             if timelimited and self.timedout:
         if name not in self._groups:
             raise KeyError("%r not in group names %r"
                            % (name, self._groups.keys()))
-        return dict(self._groups[name])
+        # Sort the groups and remove the sort keys before returning them
+        groups = self._groups[name]
+        d = {}
+        for key, items in iteritems(groups):
+            d[key] = [docnum for _, docnum in sorted(items)]
+        return d
 
     def _load_docs(self):
         # If self.docset is None, that means this results object was created

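Note the asymmetry in the collector change above: when a key function drives the sort, ascending key order already matches the result order, so the key is stored as-is; when relevance drives the sort, higher scores should come first, so the collector stores 0 - score to make the ascending tuple sort yield descending scores. A quick illustration with made-up scores:

    pairs = [(0 - 2.5, 10), (0 - 7.0, 11), (0 - 4.0, 12)]  # (negated score, docnum)
    assert [docnum for _, docnum in sorted(pairs)] == [11, 12, 10]  # 7.0 ranks first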
tests/test_sorting.py

             multiprocessing.Process.__init__(self)
             self.storage = storage
             self.indexname = indexname
-            
+
         def run(self):
             ix = self.storage.open_index(self.indexname)
             with ix.searcher() as s:
 
 
 docs = ({"id": u("zulu"), "num": 100, "tag": u("one"), "frac": 0.75},
-        {"id": u("xray"), "num": -5, "tag": u("three"), "frac": 2.0},
+        {"id": u("xray"), "num":-5, "tag": u("three"), "frac": 2.0},
         {"id": u("yankee"), "num": 3, "tag": u("two"), "frac": 5.5},
-        
+
         {"id": u("alfa"), "num": 7, "tag": u("three"), "frac": 2.25},
         {"id": u("tango"), "num": 2, "tag": u("two"), "frac": 1.75},
-        {"id": u("foxtrot"), "num": -800, "tag": u("two"), "frac": 3.25},
-        
+        {"id": u("foxtrot"), "num":-800, "tag": u("two"), "frac": 3.25},
+
         {"id": u("sierra"), "num": 1, "tag": u("one"), "frac": 4.75},
         {"id": u("whiskey"), "num": 0, "tag": u("three"), "frac": 5.25},
         {"id": u("bravo"), "num": 582045, "tag": u("three"), "frac": 1.25},
 def make_multi_index(ix):
     for i in xrange(0, len(docs), 3):
         w = ix.writer()
-        for doc in docs[i:i+3]:
+        for doc in docs[i:i + 3]:
             w.add_document(ev=u("a"), **doc)
         w.commit(merge=False)
 
 def try_sort(sortedby, key, q=None, limit=None, reverse=False):
     if q is None: q = query.Term("ev", u("a"))
-    
+
     correct = [d["id"] for d in sorted(docs, key=key, reverse=reverse)][:limit]
-    
+
     for fn in (make_single_index, make_multi_index):
         with TempIndex(get_schema()) as ix:
             fn(ix)
         w.add_document(tag=u("juliet"))
         w.add_document(tag=u("romeo"))
         w.commit()
-        
+
         with ix.reader() as r:
             _ = r.fieldcache("tag")
             assert_equal(list(r.lexicon("tag")), ["alfa", "juliet", "romeo", "sierra"])
     with TempIndex(schema, "floatcache") as ix:
         w = ix.writer()
         w.add_document(id=1, num=1.5)
-        w.add_document(id=2, num=-8.25)
+        w.add_document(id=2, num= -8.25)
         w.add_document(id=3, num=0.75)
         w.commit()
-        
+
         with ix.reader() as r:
             r.fieldcache("num")
             r.unload_fieldcache("num")
-            
+
             fc = r.fieldcache("num")
             assert not fc.hastexts
             assert_equal(fc.texts, None)
     with TempIndex(schema, "longcache") as ix:
         w = ix.writer()
         w.add_document(id=1, num=2858205080241)
-        w.add_document(id=2, num=-3572050858202)
+        w.add_document(id=2, num= -3572050858202)
         w.add_document(id=3, num=4985020582043)
         w.commit()
-        
+
         with ix.reader() as r:
             r.fieldcache("num")
             r.unload_fieldcache("num")
-            
+
             fc = r.fieldcache("num")
             assert not fc.hastexts
             assert_equal(fc.texts, None)
         make_single_index(ix)
         r1 = ix.reader()
         fc1 = r1.fieldcache("id")
-        
+
         r2 = ix.reader()
         fc2 = r2.fieldcache("id")
-        
+
         assert fc1 is fc2
-        
+
         r3 = ix.reader()
         assert r3.fieldcache_loaded("id")
-        
+
         r1.close()
         r2.close()
         del r1, fc1, r2, fc2
         import gc
         gc.collect()
-        
+
         assert not r3.fieldcache_loaded("id")
         r3.close()
 
         for char in domain:
             w.add_document(key=char)
         w.commit()
-        
+
         tasks = [MPFCTask(ix.storage, ix.indexname) for _ in xrange(4)]
         for task in tasks:
             task.start()
     try_sort("id", lambda d: d["id"])
     try_sort("id", lambda d: d["id"], limit=5)
     try_sort("id", lambda d: d["id"], reverse=True)
-    try_sort("id",  lambda d: d["id"], limit=5, reverse=True)
+    try_sort("id", lambda d: d["id"], limit=5, reverse=True)
 
 def test_multisort():
     mf = sorting.MultiFacet(["tag", "id"])
         w.add_document(id=2)
         w.add_document(id=3)
         w.commit()
-        
+
         with ix.searcher() as s:
             r = s.search(query.Every(), sortedby="key")
             assert_equal([h["id"] for h in r], [1, 2, 3])
     with TempIndex(schema, "pagesorted") as ix:
         domain = list(u("abcdefghijklmnopqrstuvwxyz"))
         random.shuffle(domain)
-        
+
         w = ix.writer()
         for char in domain:
             w.add_document(key=char)
         w.commit()
-        
+
         with ix.searcher() as s:
             r = s.search(query.Every(), sortedby="key", limit=5)
             assert_equal(r.scored_length(), 5)
             assert_equal(len(r), s.doc_count_all())
-            
+
             rp = s.search_page(query.Every(), 1, pagelen=5, sortedby="key")
             assert_equal("".join([h["key"] for h in rp]), "abcde")
             assert_equal(rp[10:], [])
-            
+
             rp = s.search_page(query.Term("key", "glonk"), 1, pagelen=5, sortedby="key")
             assert_equal(len(rp), 0)
             assert rp.is_last_page()
     w.add_document(id=5, a=u("alfa bravo bravo"), b=u("apple"), c=u("c"))
     w.add_document(id=6, a=u("alfa alfa alfa"), b=u("apple"), c=u("c"))
     w.commit(merge=False)
-    
+
     with ix.searcher() as s:
         facet = sorting.MultiFacet(["b", sorting.ScoreFacet()])
         r = s.search(q=query.Term("a", u("alfa")), sortedby=facet)
                     w.add_document(id=count, text=u(" ").join((w1, w2, w3, w4)))
                     count += 1
     w.commit()
-    
+
     def fn(searcher, docnum):
         v = dict(searcher.vector_as("frequency", docnum, "text"))
         # Give high score to documents that have equal number of "alfa"
         # and "bravo". Negate value so higher values sort first
         return 0 - (1.0 / (abs(v.get("alfa", 0) - v.get("bravo", 0)) + 1.0))
-    
+
     with ix.searcher() as s:
         q = query.And([query.Term("text", u("alfa")), query.Term("text", u("bravo"))])
-        
+
         fnfacet = sorting.FunctionFacet(fn)
         r = s.search(q, sortedby=fnfacet)
         texts = [hit["text"] for hit in r]
     w.add_document(id=5, v1=2, v2=50)
     w.add_document(id=6, v1=1, v2=200)
     w.commit()
-    
+
     with ix.searcher() as s:
         mf = sorting.MultiFacet().add_field("v1").add_field("v2", reverse=True)
         r = s.search(query.Every(), sortedby=mf)
         w = ix.writer()
         w.add_document(id=i, v=ltr)
         w.commit(merge=False)
-    
+
     with ix.searcher() as s:
         q1 = query.TermRange("v", "a", "c")
         q2 = query.TermRange("v", "d", "f")
         q3 = query.TermRange("v", "g", "i")
-        
+
         assert_equal([hit["id"] for hit in s.search(q1)], [1, 2, 4])
         assert_equal([hit["id"] for hit in s.search(q2)], [5, 7, 8])
         assert_equal([hit["id"] for hit in s.search(q3)], [0, 3, 6])
-        
+
         facet = sorting.QueryFacet({"a-c": q1, "d-f": q2, "g-i": q3})
         r = s.search(query.Every(), groupedby=facet)
         # If you specify a facet without a name, it's automatically called
         for i, ltr in enumerate(domain):
             v = "%s %s" % (ltr, domain[0 - i])
             w.add_document(v=v)
-    
+
     with ix.searcher() as s:
         q1 = query.TermRange("v", "a", "c")
         q2 = query.TermRange("v", "d", "f")
         q3 = query.TermRange("v", "g", "i")
-        
+
         facets = sorting.Facets()
         facets.add_query("myfacet", {"a-c": q1, "d-f": q2, "g-i": q3}, allow_overlap=True)
         r = s.search(query.Every(), groupedby=facets)
     w.add_document(id=3, tag=u("bravo"))
     w.add_document(id=4)
     w.commit()
-    
+
     with ix.searcher() as s:
         r = s.search(query.Every(), groupedby="tag")
         assert_equal(r.groups("tag"), {None: [2, 4], 'bravo': [3], 'alfa': [0, 1]})
     w.add_document(id=3, tag=0)
     w.add_document(id=4)
     w.commit()
-    
+
     with ix.searcher() as s:
         r = s.search(query.Every(), groupedby="tag")
         assert_equal(r.groups("tag"), {None: [2, 4], 0: [3], 1: [0, 1]})
     w.add_document(id=3, date=d2)
     w.add_document(id=4)
     w.commit()
-    
+
     with ix.searcher() as s:
         r = s.search(query.Every(), groupedby="date")
-        assert_equal(r.groups("date"),  {d1: [0, 1], d2: [3], None: [2, 4]})
+        assert_equal(r.groups("date"), {d1: [0, 1], d2: [3], None: [2, 4]})
 
 def test_range_facet():
     schema = fields.Schema(id=fields.STORED, price=fields.NUMERIC)
     w.add_document(id=4, price=500)
     w.add_document(id=5, price=125)
     w.commit()
-    
+
     with ix.searcher() as s:
         rf = sorting.RangeFacet("price", 0, 1000, 100)
         r = s.search(query.Every(), groupedby={"price": rf})
     for i in range(10):
         w.add_document(id=i, num=i)
     w.commit()
-    
+
     with ix.searcher() as s:
-        rf = sorting.RangeFacet("num", 0, 1000, [1,2,3])
+        rf = sorting.RangeFacet("num", 0, 1000, [1, 2, 3])
         r = s.search(query.Every(), groupedby={"num": rf})
         assert_equal(r.groups("num"), {(0, 1): [0],
                                        (1, 3): [1, 2],
     w.add_document(id=4, date=datetime(2001, 1, 8))
     w.add_document(id=5, date=datetime(2001, 1, 6))
     w.commit()
-    
+
     with ix.searcher() as s:
         rf = sorting.DateRangeFacet("date", datetime(2001, 1, 1),
                                     datetime(2001, 1, 20), timedelta(days=5))
 def test_relative_daterange():
     from whoosh.support.relativedelta import relativedelta
     dt = datetime
-    
+
     schema = fields.Schema(id=fields.STORED, date=fields.DATETIME)
     ix = RamStorage().create_index(schema)
     basedate = datetime(2001, 1, 1)
             w.add_document(id=count, date=basedate)
             basedate += timedelta(days=14, hours=16)
             count += 1
-    
+
     with ix.searcher() as s:
         gap = relativedelta(months=1)
         rf = sorting.DateRangeFacet("date", dt(2001, 1, 1), dt(2001, 12, 31), gap)
         w.add_document(id=2, tags=u("charlie delta echo"))
         w.add_document(id=3, tags=u("delta echo alfa"))
         w.add_document(id=4, tags=u("echo alfa bravo"))
-    
+
     with ix.searcher() as s:
         of = sorting.FieldFacet("tags", allow_overlap=True)
         r = s.search(query.Every(), groupedby={"tags": of})
                         == [(u('one'), [0, 6]),
                             (u('three'), [1, 3, 7, 8]),
                             (u('two'), [2, 4, 5])])
-    
+
     check(make_single_index)
     check(make_multi_index)
 
         w.add_document(tag=u("alfa"), size=u("medium"))
         w.add_document(tag=u("bravo"), size=u("medium"))
         w.commit()
-        
+
         correct = {(u('bravo'), u('medium')): [1, 5], (u('alfa'), u('large')): [2],
                    (u('alfa'), u('medium')): [4], (u('alfa'), u('small')): [0],
                    (u('bravo'), u('small')): [3]}
-        
+
         with ix.searcher() as s:
             facet = sorting.MultiFacet(["tag", "size"])
             r = s.search(query.Every(), groupedby={"tag/size" : facet})
         group = groups[i % len(groups)]
         source.append({"key": key, "group": group})
     source.sort(key=lambda x: (x["key"], x["group"]))
-    
+
     sample = source[:]
     random.shuffle(sample)
-    
+
     with TempIndex(schema, "sortfilter") as ix:
         w = ix.writer()
         for i, fs in enumerate(sample):
                 w.commit(merge=False)
                 w = ix.writer()
         w.commit()
-        
+
         fq = query.Term("group", u("bravo"))
-        
+
         with ix.searcher() as s:
             r = s.search(query.Every(), sortedby=("key", "group"), filter=fq, limit=20)
             assert_equal([h.fields() for h in r],
                          [d for d in source if d["group"] == "bravo"][:20])
-            
+
             fq = query.Term("group", u("bravo"))
             r = s.search(query.Every(), sortedby=("key", "group"), filter=fq, limit=None)
             assert_equal([h.fields() for h in r],
                          [d for d in source if d["group"] == "bravo"])
-            
+
         ix.optimize()
-        
+
         with ix.searcher() as s:
             r = s.search(query.Every(), sortedby=("key", "group"), filter=fq, limit=20)
             assert_equal([h.fields() for h in r],
                          [d for d in source if d["group"] == "bravo"][:20])
-            
+
             fq = query.Term("group", u("bravo"))
             r = s.search(query.Every(), sortedby=("key", "group"), filter=fq, limit=None)
             assert_equal([h.fields() for h in r],
     schema = fields.Schema(name=fields.ID(stored=True),
                            price=fields.NUMERIC,
                            quant=fields.NUMERIC)
-    
+
     with TempIndex(schema, "customsort") as ix:
         w = ix.writer()
         w.add_document(name=u("A"), price=200, quant=9)
         w.add_document(name=u("B"), price=250, quant=11)
         w.add_document(name=u("C"), price=200, quant=10)
         w.commit()
-        
+
         with ix.searcher() as s:
             cs = s.sorter()
             cs.add_field("price")
             cs.add_field("quant", reverse=True)
             r = cs.sort_query(query.Every(), limit=None)
             assert_equal([hit["name"] for hit in r], list(u("DCAFBE")))
-            
+
 def test_sorting_function():
     schema = fields.Schema(id=fields.STORED, text=fields.TEXT(stored=True, vector=True))
     ix = RamStorage().create_index(schema)
                     w.add_document(id=count, text=u(" ").join((w1, w2, w3, w4)))
                     count += 1
     w.commit()
-    
+
     def fn(searcher, docnum):
         v = dict(searcher.vector_as("frequency", docnum, "text"))
         # Sort documents that have equal number of "alfa"
         # and "bravo" first
         return 0 - 1.0 / (abs(v.get("alfa", 0) - v.get("bravo", 0)) + 1.0)
     fnfacet = sorting.FunctionFacet(fn)
-    
+
     with ix.searcher() as s:
         q = query.And([query.Term("text", u("alfa")), query.Term("text", u("bravo"))])
         results = s.search(q, sortedby=fnfacet)
             tks = t.split()
             assert_equal(tks.count("alfa"), tks.count("bravo"))
 
+def test_sorted_groups():
+    schema = fields.Schema(a=fields.STORED, b=fields.TEXT, c=fields.ID)
+    ix = RamStorage().create_index(schema)
+    with ix.writer() as w:
+        w.add_document(a=0, b=u("blah"), c=u("apple"))
+        w.add_document(a=1, b=u("blah blah"), c=u("bear"))
+        w.add_document(a=2, b=u("blah blah blah"), c=u("apple"))
+        w.add_document(a=3, b=u("blah blah blah blah"), c=u("bear"))
+        w.add_document(a=4, b=u("blah blah blah blah blah"), c=u("apple"))
+        w.add_document(a=5, b=u("blah blah blah blah blah blah"), c=u("bear"))
 
-#def test_custom_sort2():
-#    from array import array
-#    from whoosh.searching import Results
-#    
-#    class CustomSort(object):
-#        def __init__(self, *criteria):
-#            self.criteria = criteria
-#            self.arrays = None
-#            
-#        def cache(self, searcher):
-#            self.arrays = []
-#            r = searcher.reader()
-#            for name, reverse in self.criteria:
-#                arry = array("i", [0] * r.doc_count_all())
-#                field = ix.schema[name]
-#                for i, (token, _) in enumerate(field.sortable_values(r, name)):
-#                    if reverse: i = 0 - i
-#                    postings = r.postings(name, token)
-#                    for docid in postings.all_ids():
-#                        arry[docid] = i
-#                self.arrays.append(arry)
-#                
-#        def key_fn(self, docnum):
-#            return tuple(arry[docnum] for arry in self.arrays)
-#        
-#        def sort_query(self, searcher, q):
-#            if self.arrays is None:
-#                self.cache(searcher)
-#            
-#            return self._results(searcher, q, searcher.docs_for_query(q))
-#        
-#        def sort_all(self, searcher):
-#            if self.arrays is None:
-#                self.cache(searcher)
-#            
-#            return self._results(searcher, None, searcher.reader().all_doc_ids())
-#            
-#        def _results(self, searcher, q, docnums):
-#            docnums = sorted(docnums, key=self.key_fn)
-#            return Results(searcher, q, [(None, docnum) for docnum in docnums], None)
-#            
-#    
-#    schema = fields.Schema(name=fields.ID(stored=True),
-#                           price=fields.NUMERIC,
-#                           quant=fields.NUMERIC)
-#    
-#    with TempIndex(schema, "customsort") as ix:
-#        w = ix.writer()
-#        w.add_document(name=u("A"), price=200, quant=9)
-#        w.add_document(name=u("E"), price=300, quant=4)
-#        w.add_document(name=u("F"), price=200, quant=8)
-#        w.add_document(name=u("D"), price=150, quant=5)
-#        w.add_document(name=u("B"), price=250, quant=11)
-#        w.add_document(name=u("C"), price=200, quant=10)
-#        w.commit()
-#        
-#        cs = CustomSort(("price", False), ("quant", True))
-#        with ix.searcher() as s:
-#            assert_equal([hit["name"] for hit in cs.sort_query(s, query.Every())],
-#                          list("DCAFBE"))
+    with ix.searcher() as s:
+        q = query.Term("b", "blah")
+        r = s.search(q, groupedby="c")
+        gs = r.groups("c")
+        assert_equal(gs["apple"], [4, 2, 0])
+        assert_equal(gs["bear"], [5, 3, 1])
+
+
+
+
+