Commits

Matt Chaput committed 484850a Merge

Merging changes from mainline.

  • Parent commits cc7f53b, ceaf0a9
  • Branches betterq

Files changed (3)

File src/whoosh/__init__.py

 # those of the authors and should not be interpreted as representing official
 # policies, either expressed or implied, of Matt Chaput.
 
-__version__ = (1, 8, 3)
+__version__ = (1, 8, 4)
 
 
 def versionstring(build=True, extra=True):

File src/whoosh/searching.py

         self._ix = fromindex
         
         if parent:
+            self.parent = parent
             self.schema = parent.schema
             self._doccount = parent._doccount
             self._idf_cache = parent._idf_cache
             self._filter_cache = parent._filter_cache
         else:
+            self.parent = None
             self.schema = self.ixreader.schema
             self._doccount = self.ixreader.doc_count_all()
             self._idf_cache = {}
         for name in ("stored_fields", "all_stored_fields", "vector", "vector_as",
                      "lexicon", "frequency", "doc_frequency", 
                      "min_length", "max_length", "max_weight", "max_wol",
-                     "field_length", "doc_field_length",
-                     "min_field_length", "max_field_length",
-                     ):
+                     "doc_field_length"):
             setattr(self, name, getattr(self.ixreader, name))
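
The loop above re-exports a fixed list of reader methods as attributes of the searcher; field_length, min_field_length and max_field_length are dropped from the list because they now need per-searcher logic (the new methods added below delegate to the parent searcher). A toy sketch of the forwarding pattern itself, using stand-in names:

    class Forwarder:
        # Re-export selected methods of a backend object as our own
        # attributes, the same trick the setattr loop uses for reader methods
        def __init__(self, backend):
            for name in ("lower", "upper"):  # str methods as stand-ins
                setattr(self, name, getattr(backend, name))

    f = Forwarder("Hello")
    assert f.upper() == "HELLO"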
 
     def __enter__(self):
         
         return self._doccount
 
+    def field_length(self, fieldname):
+        if self.parent:
+            return self.parent.field_length(fieldname)
+        else:
+            return self.reader().field_length(fieldname)
+        
+    def max_field_length(self, fieldname):
+        if self.parent:
+            return self.parent.max_field_length(fieldname)
+        else:
+            return self.reader().max_field_length(fieldname)
+
     def up_to_date(self):
         """Returns True if this Searcher represents the latest version of the
         index, for backends that support versioning.
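
The new field_length() and max_field_length() defer to the parent searcher when one exists, so a sub-searcher opened on a single segment still reports corpus-wide statistics to the scoring functions. This is also why avg_field_length() below now calls self.field_length() instead of going straight to the reader. A minimal sketch of the delegation pattern (ReaderStub and SearcherStub are illustrative stand-ins, not Whoosh classes):

    class ReaderStub:
        # Knows only the totals for its own segment
        def __init__(self, lengths):
            self._lengths = lengths  # {fieldname: total term count}

        def field_length(self, fieldname):
            return self._lengths.get(fieldname, 0)

    class SearcherStub:
        def __init__(self, reader, parent=None):
            self._reader = reader
            self.parent = parent

        def field_length(self, fieldname):
            # Sub-searchers ask the top-level searcher so scoring sees
            # global, not per-segment, field lengths
            if self.parent:
                return self.parent.field_length(fieldname)
            return self._reader.field_length(fieldname)

    top = SearcherStub(ReaderStub({"text": 100}))
    sub = SearcherStub(ReaderStub({"text": 40}), parent=top)
    assert sub.field_length("text") == 100
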
         self.is_closed = True
 
     def avg_field_length(self, fieldname, default=None):
-        if not self.ixreader.schema[fieldname].scorable:
+        if not self.schema[fieldname].scorable:
             return default
-        return self.ixreader.field_length(fieldname) / (self._doccount or 1)
+        return self.field_length(fieldname) / (self._doccount or 1)
 
     def reader(self):
         """Returns the underlying :class:`~whoosh.reading.IndexReader`.
         
         return c
     
-    def docs_for_query(self, q, leafs=True):
-        if self.subsearchers and leafs:
-            for s, offset in self.subsearchers:
-                for docnum in q.docs(s):
-                    yield docnum + offset
-        else:
-            for docnum in q.docs(self):
-                yield docnum
-
     def key_terms(self, docnums, fieldname, numterms=5,
                   model=classify.Bo1Model, normalize=True):
         """Returns the 'numterms' most important terms from the documents
               if (not comb) or docnum in comb]
         docset = set(docnum for _, docnum in ls)
         ls.sort(key=lambda x: (0 - x[0], x[1]))
-        return Results(self, q, ls, docset, runtime=now() - t)
+        return Results(self, q, ls, docset, runtime=now() - t, filter=filter)
 
     def define_facets(self, name, qs, save=False):
         def doclists_for_searcher(s):
                                         counts=counts)
         return groups
     
+    def docs_for_query(self, q, leafs=True):
+        if self.subsearchers and leafs:
+            for s, offset in self.subsearchers:
+                for docnum in q.docs(s):
+                    yield docnum + offset
+        else:
+            for docnum in q.docs(self):
+                yield docnum
+    
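
docs_for_query() (moved here, below define_facets()) yields matching document numbers; on a composite searcher it runs the query against each leaf sub-searcher and adds that leaf's offset to translate segment-local numbers into global ones. The offset arithmetic in isolation, with made-up segment sizes:

    # Hypothetical segments of 10, 25 and 5 documents give subsearcher
    # offsets 0, 10 and 35; local docnum 3 in the third segment is
    # global docnum 3 + 35 = 38.
    sizes = [10, 25, 5]
    offsets, total = [], 0
    for size in sizes:
        offsets.append(total)
        total += size
    assert offsets == [0, 10, 35]
    assert 3 + offsets[2] == 38
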
     def search(self, q, limit=10, sortedby=None, reverse=False, groupedby=None,
                optimize=True, scored=True, filter=None, mask=None,
                collector=None):
             collector.scored = scored
             collector.reverse = reverse
         
+        if filter:
+            filter = self._filter_to_comb(filter)
+        if mask:
+            mask = self._filter_to_comb(mask)
+        
         return collector.search(self, q, allow=filter, restrict=mask)
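
search() now normalizes filter and mask through _filter_to_comb() before the collector sees them, so the collector (and the Results object) always receives plain sets of document numbers. A usage sketch, assuming the RamStorage import path of Whoosh 1.x; the schema and documents are made up:

    from whoosh import fields, query
    from whoosh.filedb.filestore import RamStorage

    schema = fields.Schema(text=fields.TEXT(stored=True))
    ix = RamStorage().create_index(schema)
    w = ix.writer()
    w.add_document(text=u"alfa bravo")     # docnum 0
    w.add_document(text=u"bravo charlie")  # docnum 1
    w.commit()

    with ix.searcher() as s:
        # Only allow docnum 0; additionally exclude anything in the mask
        r = s.search(query.Term("text", u"bravo"),
                     filter=set([0]), mask=set([1]))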
         
 
         if self.limit and self.limit > searcher.doc_count_all():
             self.limit = None
         
-        self._allow = None
-        self._restrict = None
-        if allow:
-            self._allow = self._searcher._filter_to_comb(allow)
-        if restrict:
-            self._restrict = self._searcher._filter_to_comb(restrict)
+        self._allow = allow
+        self._restrict = restrict
             
         if self.timelimit:
             self.timer = threading.Timer(self.timelimit, self._timestop)
         list of matched documents.
         """
         
+        docset = self.docset
         offset = self.doc_offset
         limit = self.limit
         items = self._items
             if restrict and offsetid in restrict:
                 continue
             
+            docset.add(offsetid)
+            
             if keyfns:
                 for name, keyfn in keyfns.iteritems():
                     if name not in self.groups:
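
In the unscored loop, every document number that survives the allow/restrict checks is now added to self.docset as it is collected, so an exhaustive search finishes with a complete docset at no extra cost. The bookkeeping in isolation:

    # Toy version of the collect loop's filtering and docset bookkeeping
    allow, restrict = set([1, 2, 3]), set([2])
    docset = set()
    for offsetid in range(5):
        if allow and offsetid not in allow:
            continue
        if restrict and offsetid in restrict:
            continue
        docset.add(offsetid)
    assert docset == set([1, 3])
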
         ``None``.
         """
         
-        docset = self.docset
-        
         # Can't use quality optimizations if the matcher doesn't support them
         usequality = usequality and matcher.supports_block_quality()
         replace = self.replace
             id = matcher.id()
             offsetid = id + offset
             
-            if not usequality:
-                docset.add(offsetid)
-            
             # If we're using quality optimizations, check whether the current
             # posting has higher quality than the minimum before yielding it.
             if usequality:
         
         docset = self.docset or None
         return Results(self._searcher, self._q, self.items(), docset,
-                       groups=self.groups, runtime=runtime)
+                       groups=self.groups, runtime=runtime, filter=self._allow,
+                       mask=self._restrict)
 
 
 class TermTrackingCollector(Collector):
     that position in the results.
     """
 
-    def __init__(self, searcher, q, top_n, docset, groups=None, runtime=-1):
+    def __init__(self, searcher, q, top_n, docset, groups=None, runtime=-1,
+                 filter=None, mask=None):
         """
         :param searcher: the :class:`Searcher` object that produced these
             results.
         self.docset = docset
         self._groups = groups or {}
         self.runtime = runtime
+        self._filter = filter
+        self._mask = mask
         self._terms = None
         
         self.fragmenter = highlight.ContextFragmenter()
         return self._groups[name]
     
     def _load_docs(self):
-        self.docset = set(self.searcher.docs_for_query(self.q))
+        filter = self._filter
+        mask = self._mask
+        gen = self.searcher.docs_for_query(self.q)
+        if filter or mask:
+            docset = set()
+            for docnum in gen:
+                if filter and docnum not in filter:
+                    continue
+                if mask and docnum in mask:
+                    continue
+                docset.add(docnum)
+        else:
+            docset = set(gen)
+        self.docset = docset
 
     def has_exact_length(self):
         """True if this results object already knows the exact number of
         """
         
         return self.__class__(self.searcher, self.q, self.top_n[:],
-                              copy.copy(self.docset), runtime=self.runtime)
+                              copy.copy(self.docset), runtime=self.runtime,
+                              filter=self._filter, mask=self._mask)
 
     def score(self, n):
         """Returns the score for the document at the Nth position in the list

File tests/test_results.py

         assert_equal([hit["id"] for hit in r1c], [2, 3, 4])
         assert_equal(r1c.scored_length(), 3)
 
+def test_extend_filtered():
+    schema = fields.Schema(id=fields.STORED, text=fields.TEXT(stored=True))
+    ix = RamStorage().create_index(schema)
+    w = ix.writer()
+    w.add_document(id=1, text=u"alfa bravo charlie")
+    w.add_document(id=2, text=u"bravo charlie delta")
+    w.add_document(id=3, text=u"juliet delta echo")
+    w.add_document(id=4, text=u"delta bravo alfa")
+    w.add_document(id=5, text=u"foxtrot sierra tango")
+    w.commit()
+    
+    hits = lambda result: [hit["id"] for hit in result]
+    
+    with ix.searcher() as s:
+        r1 = s.search(query.Term("text", u"alfa"), filter=set([1, 4]))
+        assert_equal(r1._filter, set([1, 4]))
+        assert_equal(len(r1.top_n), 0)
+        
+        r2 = s.search(query.Term("text", u"bravo"))
+        assert_equal(len(r2.top_n), 3)
+        assert_equal(hits(r2), [1, 2, 4])
+        
+        r3 = r1.copy()
+        assert_equal(r3._filter, set([1, 4]))
+        assert_equal(len(r3.top_n), 0)
+        r3.extend(r2)
+        assert_equal(len(r3.top_n), 3)
+        assert_equal(hits(r3), [1, 2, 4])
+
 def test_pages():
     from whoosh.scoring import Frequency