Commits

Mikhail Korobov  committed ba6d501

iterkeys and iteritems methods for BytesTrie and RecordTrie

  • Participants
  • Parent commits bf80ca8

Comments (0)

Files changed (4)

File bench/speed.py

 if __name__ == '__main__':
     benchmark()
     #profiling()
-    print('\n~~~~~~~~~~~~~~\n')
+    print('\n~~~~~~~~~~~~~~\n')

File src/marisa_trie.pyx

          Return an iterator over keys that have a prefix ``prefix``.
          """
          cdef agent.Agent ag
-         cdef unicode key
          cdef bytes b_prefix = prefix.encode('utf8')
          ag.set_query(b_prefix)
 
         return res
 
     cpdef list items(self, unicode prefix=""):
+        # copied from iteritems for speed
         cdef bytes b_prefix = prefix.encode('utf8')
         cdef bytes value
         cdef unicode key
             )
         return res
 
+    def iteritems(self, unicode prefix=""):
+        cdef bytes b_prefix = prefix.encode('utf8')
+        cdef bytes value
+        cdef unicode key
+        cdef unsigned char* raw_key
+        cdef int i, value_len
+
+        cdef agent.Agent ag
+        ag.set_query(b_prefix)
+
+        while self._trie.predictive_search(ag):
+            raw_key = <unsigned char*>ag.key().ptr()
+
+            for i in range(0, ag.key().length()):
+                if raw_key[i] == self._c_value_separator:
+                    break
+
+            key = raw_key[:i].decode('utf8')
+            value = raw_key[i+1:ag.key().length()]
+
+            yield key, value
+
+
     cpdef list keys(self, unicode prefix=""):
+        # copied from iterkeys for speed
         cdef bytes b_prefix = prefix.encode('utf8')
         cdef unicode key
         cdef unsigned char* raw_key
                     break
         return res
 
+    def iterkeys(self, unicode prefix=""):
+        cdef bytes b_prefix = prefix.encode('utf8')
+        cdef unicode key
+        cdef unsigned char* raw_key
+        cdef int i
+
+        cdef agent.Agent ag
+        ag.set_query(b_prefix)
+
+        while self._trie.predictive_search(ag):
+            raw_key = <unsigned char*>ag.key().ptr()
+
+            for i in range(0, ag.key().length()):
+                if raw_key[i] == self._c_value_separator:
+                    yield raw_key[:i].decode('utf8')
+                    break
+
 
 cdef class _UnpackTrie(BytesTrie):
 
         cdef list items = BytesTrie.items(self, prefix)
         return [(key, self._unpack(val)) for (key, val) in items]
 
+    def iteritems(self, unicode prefix=""):
+        return ((key, self._unpack(val)) for key, val in BytesTrie.iteritems(self, prefix))
 
 
 cdef class RecordTrie(_UnpackTrie):

File tests/test_payload.py

         assert trie.keys('food') == []
         assert trie.keys('bar') == []
 
+    def test_iterkeys(self):
+        keys = get_random_words(1000)
+        values = get_random_binary(1000)
+
+        trie = marisa_trie.BytesTrie(zip(keys, values))
+        assert trie.keys() == list(trie.iterkeys())
+
+        for key in keys:
+            prefix = key[:5]
+            assert trie.keys(prefix) == list(trie.iterkeys(prefix))
+
     def test_items(self):
         data = [
             ('fo',  b'y'),
         assert trie.items('food') == []
         assert trie.items('bar') == []
 
+    def test_iteritems(self):
+        keys = get_random_words(1000)
+        values = get_random_binary(1000)
+
+        trie = marisa_trie.BytesTrie(zip(keys, values))
+        assert trie.items() == list(trie.iteritems())
+
+        for key in keys:
+            prefix = key[:5]
+            assert trie.items(prefix) == list(trie.iteritems(prefix))
+
     def test_pickling(self):
         trie = marisa_trie.BytesTrie([
             ('foo', b'foo'),
 
         assert set(trie.items()) == set(data)
 
+    def test_iteritems(self):
+        fmt, data = self.data()
+        trie = marisa_trie.RecordTrie(fmt, data)
+        assert trie.items() == list(trie.iteritems())
+
+        for key, value in data:
+            prefix = key[:5]
+            assert trie.items(prefix) == list(trie.iteritems(prefix))
+
+
     def test_prefixes(self):
         trie = marisa_trie.RecordTrie(str("<H"), [
             ('foo', [1]),

File tests/test_trie.py

     trie = marisa_trie.Trie(['foo', 'f', 'bar'])
     assert len(trie) == 3
 
+
 def test_prefixes():
     trie = marisa_trie.Trie(['foo', 'f', 'foobar', 'bar'])
     assert trie.prefixes('foobar') == ['f', 'foo', 'foobar']
 
     assert list(trie.iter_prefixes('foobar')) == ['f', 'foo', 'foobar']
 
+
 def test_keys():
     keys = ['foo', 'f', 'foobar', 'bar']
     trie = marisa_trie.Trie(keys)
     assert set(trie.keys()) == set(keys)
 
+
 def test_keys_prefix():
     keys = ['foo', 'f', 'foobar', 'bar']
     trie = marisa_trie.Trie(keys)
     assert trie.keys('foobarz') == []
 
 
+def test_iterkeys():
+    keys = get_random_words(1000)
+    trie = marisa_trie.Trie(keys)
+    assert trie.keys() == list(trie.iterkeys())
+
+    for key in keys:
+        prefix = key[:5]
+        assert trie.keys(prefix) == list(trie.iterkeys(prefix))
+
+
 def test_invalid_file():
     try:
         marisa_trie.Trie().load(__file__)
 #
 #    # trie is frozen, no new keys are allowed
 #    with pytest.raises(KeyError):
-#        trie['z'] = 200
+#        trie['z'] = 200