Commits

coady committed c471f2e

Grouping collectors from contrib module.

  • Parent commits 21e736b

Files changed (4)

File docs/engine.rst

 
   .. attribute:: filters
 
-    Mapping of cached filters by field, which are used for facet counts.
+    Mapping of cached filters by field, also used for facet counts.
+
+  .. attribute:: groupings
+
+    Mapping of cached groupings by field, optimized for facet counts of unique fields.
 
   .. attribute:: sorters
 
 
   .. automethod:: __getitem__
 
+Grouping
+^^^^^^^^^^^^^
+.. versionadded:: 1.2+
+  requires grouping contrib module in lucene >= 3.3
+.. note:: This interface is experimental and might change in incompatible ways in the next release.
+.. autoclass:: Grouping
+  :members:
+
+  .. automethod:: __len__
+
+  .. automethod:: __iter__
+
 Field
 ^^^^^^^^^^^^^
 .. autoclass:: Field
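
A minimal sketch of the documented interface, assuming an IndexSearcher named searcher over an index with a unique 'state' field (the names mirror the tests below); the hasattr guard reflects the version requirement noted above:

    if hasattr(lucene, 'TermFirstPassGroupingCollector'):  # grouping contrib available
        grouping = searcher.grouping('state')
        len(grouping)   # number of collected groups
        list(grouping)  # distinct values of the grouped field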

File lupyne/engine/documents.py

         "Return `Hits`_ sorted by key function applied to doc ids."
         scoredocs = sorted(self.scoredocs, key=lambda scoredoc: key(scoredoc.doc), reverse=reverse)
         return type(self)(self.searcher, scoredocs, self.count, self.maxscore, self.fields)
+
+class Grouping(object):
+    """Delegated lucene SearchGroups with optimized faceting.
+    
+    :param searcher: `IndexSearcher`_ which can retrieve documents
+    :param field: unique field name to group by
+    :param query: lucene Query to select groups
+    :param count: maximum number of groups
+    :param sort: lucene Sort to order groups
+    """
+    def __init__(self, searcher, field, query=None, count=None, sort=None):
+        self.searcher, self.field = searcher, field
+        self.query = query or lucene.MatchAllDocsQuery()
+        self.sort = sort or lucene.Sort.RELEVANCE
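+        # with no count given, an extra pass counts all distinct groups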
+        if count is None:
+            collector = lucene.TermAllGroupsCollector(field)
+            lucene.IndexSearcher.search(self.searcher, self.query, collector)
+            count = collector.groupCount
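+        # first pass collects the top group values, ordered by the group sort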
+        collector = lucene.TermFirstPassGroupingCollector(field, self.sort, count)
+        lucene.IndexSearcher.search(self.searcher, self.query, collector)
+        self.searchgroups = collector.getTopGroups(0, False).of_(lucene.SearchGroup)
+    def __len__(self):
+        return self.searchgroups.size()
+    def __iter__(self):
+        for searchgroup in self.searchgroups:
+            yield searchgroup.groupValue.toString()
+    def facets(self, filter):
+        "Generate field values and counts which match given filter."
+        collector = lucene.TermSecondPassGroupingCollector(self.field, self.searchgroups, self.sort, self.sort, 1, False, False, False)
+        lucene.IndexSearcher.search(self.searcher, self.query, filter, collector)
+        for groupdocs in collector.getTopGroups(0).groups:
+            yield groupdocs.groupValue.toString(), groupdocs.totalHits
+    def groups(self, count=1, sort=None, scores=False, maxscore=False):
+        """Generate grouped `Hits`_ from second pass grouping collector.
+        
+        :param count: maximum number of docs per group
+        :param sort: lucene Sort to order docs within group
+        :param scores: compute scores for candidate results
+        :param maxscore: compute maximum score of all results
+        """
+        sort = sort or self.sort
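+        # relevance ordering requires scores, so force their computation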
+        if sort == lucene.Sort.RELEVANCE:
+            scores = maxscore = True
+        collector = lucene.TermSecondPassGroupingCollector(self.field, self.searchgroups, self.sort, sort, count, scores, maxscore, False)
+        lucene.IndexSearcher.search(self.searcher, self.query, collector)
+        for groupdocs in collector.getTopGroups(0).groups:
+            hits = Hits(self.searcher, groupdocs.scoreDocs, groupdocs.totalHits, groupdocs.maxScore, getattr(self, 'fields', None))
+            hits.value = groupdocs.groupValue.toString()
+            yield hits
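
A sketch of the generators above in use; searcher, query, and the field names are placeholders:

    grouping = Grouping(searcher, 'state', query=query)
    facets = dict(grouping.facets(None))        # {value: count}; a lucene Filter may restrict the counts
    for hits in grouping.groups(count=2, scores=True):
        hits.value                              # the group's field value
        for hit in hits:                        # ordinary Hits iteration within each group
            hit['county']                       # stored fields remain accessible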

File lupyne/engine/indexers.py

 import warnings
 import lucene
 from .queries import Query, TermsFilter, SortField, Highlighter, FastVectorHighlighter, SpellChecker, SpellParser
-from .documents import Field, Document, Hits
+from .documents import Field, Document, Hits, Grouping
 from .spatial import DistanceComparator
 
 class Atomic(object):
         self.owned = closing([self.indexReader])
         self.analyzer = self.shared.analyzer(analyzer)
         self.filters, self.sorters, self.spellcheckers = {}, {}, {}
-        self.termsfilters = set()
+        self.termsfilters, self.groupings = set(), {}
     @classmethod
     def load(cls, directory, analyzer=None):
         "Open `IndexSearcher`_ with a lucene RAMDirectory, loading index into memory."
             query = lucene.CachingWrapperFilter(query)
         for key in keys:
             filters = self.filters.get(key)
-            if isinstance(filters, lucene.Filter):
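+            # a registered grouping counts every value of a unique field with one collector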
+            if key in self.groupings:
+                counts[key] = dict(self.groupings[key].facets(query))
+            elif isinstance(filters, lucene.Filter):
                 counts[key] = self.overlap(query, filters)
             else:
                 name, value = (key, None) if isinstance(key, basestring) else key
                         filters[value] = Query.term(name, value).filter()
                     counts[name][value] = self.overlap(query, filters[value])
         return dict(counts)
+    def grouping(self, field, query=None, count=None, sort=None):
+        "Return `Grouping`_ for unique field and lucene search parameters."
+        try:
+            return self.groupings[field]
+        except KeyError:
+            return Grouping(self, field, query, count, sort)
     def sorter(self, field, type='string', parser=None, reverse=False):
         "Return `SortField`_ with cached attributes if available."
         sorter = self.sorters.get(field, SortField(field, type, parser, reverse))
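
How the pieces fit together, with placeholder names: registering a Grouping in the groupings mapping routes facet counts through it, and grouping() then returns the cached instance:

    indexer.groupings['state'] = indexer.grouping('state')  # register for optimized facet counts
    counts = indexer.facets(query, 'state')                  # {'state': {value: count, ...}}
    grouping = indexer.grouping('state')                     # returns the cached Grouping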

File test/local.py

         assert la == 'CA.Los Angeles' and facets[la] > 100
         assert orange == 'CA.Orange' and facets[orange] > 10
         (field, facets), = indexer.facets(query, ('state.county', 'CA.*')).items()
-        assert all(value.startswith('CA.') for value in facets) and set(facets) < set(indexer.filters['state.county'])
+        assert all(value.startswith('CA.') for value in facets) and set(facets) < set(indexer.filters[field])
+        if hasattr(lucene, 'TermFirstPassGroupingCollector'):
+            assert set(indexer.grouping('state', count=1)) < set(indexer.grouping('state')) == set(states)
+            grouping = indexer.grouping(field, query, sort=lucene.Sort(indexer.sorter(field)))
+            assert len(grouping) == 2 and list(grouping) == [la, orange]
+            for value, (name, count) in zip(grouping, grouping.facets(None)):
+                assert value == name and count > 0
+            grouping = indexer.groupings[field] = indexer.grouping(field, engine.Query.term('state', 'CA'))
+            assert indexer.facets(query, field)[field] == facets
+            hits = next(grouping.groups())
+            assert hits.value == 'CA.Los Angeles' and hits.count > 100 and len(hits) == 1
+            hit, = hits
+            assert hit['county'] == 'Los Angeles' and hits.maxscore >= hit.score > 0
+            hits = next(grouping.groups(count=2, sort=lucene.Sort(indexer.sorter('zipcode')), scores=True))
+            assert hits.value == 'CA.Los Angeles' and math.isnan(hits.maxscore) and len(hits) == 2
+            assert all(hit.score > 0 and hit['zipcode'] > '90000' for hit in hits)
         for count in (None, len(indexer)):
             hits = indexer.search(query, count=count, timeout=0.01)
             assert 0 <= len(hits) <= indexer.count(query) and hits.count in (None, len(hits)) and hits.maxscore in (None, 1.0)