Commits

coady  committed 434bd85

Hit sorting and filtering.

  • Participants
  • Parent commits 2d0e33a

Comments (0)

Files changed (2)

File lupyne/engine/documents.py

             group.ids.append(id)
             group.scores.append(score)
         return sorted(groups.values(), key=lambda group: group.__dict__.pop('index'))
+    def filter(self, func):
+        "Return `Hits`_ filtered by function applied to doc ids."
+        ids, scores = [], []
+        for id, score in self.items():
+            if func(id):
+                ids.append(id)
+                scores.append(score)
+        return type(self)(self.searcher, ids, scores, fields=self.fields)
+    def sorted(self, key, reverse=False):
+        "Return `Hits`_ sorted by key function applied to doc ids."
+        ids = sorted(self.ids, key=key, reverse=reverse)
+        scores = list(map(dict(self.items()).__getitem__, ids))
+        return type(self)(self.searcher, ids, scores, self.count, self.maxscore, self.fields)

File test/local.py

         assert 1 == counts[0] < counts[2] < counts[1]
         assert len(field.within(x, y, 10**8)) == 1
         self.assertRaises(NameError, list, field.radiate(y, x, 1, 0))
+        hits = hits.filter(lambda id: distances[id] < 10**4)
+        assert 0 < len(hits) < sum(counts.values())
+        hits = hits.sorted(distances.__getitem__, reverse=True)
+        assert 0 == distances[hits.ids[-1]] < distances[hits.ids[0]] < 10**4
         if hasattr(lucene, 'LatLongDistanceFilter'):
             with assertWarns(DeprecationWarning):
                 f = field.filter(x, y, 10**4, 'longitude', 'latitude')