Commits

Mikhail Korobov  committed cd3a3c0

New methods for Trie: __getitem__, get, items, iteritems.

  • Participants
  • Parent commits 6d8c34b

Comments (0)

Files changed (3)

     False
 
 Each key is assigned a unique ID from 0 to (n - 1), where n is the
-number of keys; you can use this ID to store a value in a
-separate structure (e.g. python list)::
+number of keys::
 
     >>> trie.key_id(u'key2')
     1
+    >>> trie[u'key2']  # alternative syntax
+    1
+
+Note that you can't assign a value to a ``marisa_trie.Trie`` key,
+but you can use the returned ID to store a value in a separate data structure
+(e.g. in a Python list or NumPy array).
 
 A key can be reconstructed from its ID::
 
 
 (iterator version ``.iterkeys(prefix)`` is also available).
 
+Use the ``items()`` method to return all (key, ID) pairs::
+
+    >>> trie.items()
+    [(u'key1', 0), (u'key12', 2), (u'key2', 1)]
+
+Filter them by prefix::
+
+    >>> trie.items(u'key1')
+    [(u'key1', 0), (u'key12', 2)]
+
+(iterator version ``.iteritems(prefix)`` is also available).
+
 marisa_trie.RecordTrie
 ----------------------
 

File src/marisa_trie.pyx

             raise KeyError(key)
         return res
 
+    def __getitem__(self, unicode key):
+        return self.key_id(key)
+
+    def get(self, key, default=None):
+        """
+        Return a key id for a given key or ``default`` if the key is not found.
+        """
+        cdef bytes b_key
+        cdef int res
+
+        if isinstance(key, unicode):
+            b_key = (<unicode>key).encode('utf8')
+        else:
+            b_key = key
+
+        res = self._key_id(b_key)
+        if res == -1:
+            return default
+        return res
+
     cpdef restore_key(self, int index):
         """
         Return a key given its index (obtained by ``key_id`` method).
             res.append(_get_key(ag))
         return res
 
+    def iteritems(self, unicode prefix=""):
+        """
+        Return an iterator over items that have a prefix ``prefix``.
+        """
+        cdef bytes b_prefix = prefix.encode('utf8')
+        cdef agent.Agent ag
+        ag.set_query(b_prefix)
+
+        while self._trie.predictive_search(ag):
+            yield _get_key(ag), ag.key().id()
+
+    def items(self, unicode prefix=""):
+        # inlined for speed
+        cdef list res = []
+        cdef bytes b_prefix = prefix.encode('utf8')
+        cdef agent.Agent ag
+        ag.set_query(b_prefix)
+
+        while self._trie.predictive_search(ag):
+            res.append((_get_key(ag), ag.key().id()))
+
+        return res
+
 
 # This symbol is not allowed in utf8 so it is safe to use
 # as a separator between utf8-encoded string and binary payload.

File tests/test_trie.py

     assert 'foo' in trie
     assert 'f' not in trie
 
+
 def test_key_id():
     words = ['foo', 'bar', 'f']
     trie = marisa_trie.Trie(words)
         key_id = trie.key_id(word)
         assert trie.restore_key(key_id) == word
 
-
     key_ids = [trie.key_id(word) for word in words]
     non_existing_id = max(key_ids) + 1
 
     with pytest.raises(KeyError):
         print(trie.key_id('fo'))
 
+
+def test_getitem():
+    words = ['foo', 'bar', 'f']
+    trie = marisa_trie.Trie(words)
+    for word in words:
+        key_id = trie[word]
+        assert trie.restore_key(key_id) == word
+
+    key_ids = [trie[word] for word in words]
+    non_existing_id = max(key_ids) + 1
+
+    with pytest.raises(KeyError):
+        trie.restore_key(non_existing_id)
+
+    with pytest.raises(KeyError):
+        print(trie['fo'])
+
+
+def test_get():
+    words = ['foo', 'bar', 'f']
+    trie = marisa_trie.Trie(words)
+    for word in words:
+        key_id = trie.get(word)
+        assert trie.restore_key(key_id) == word
+
+        key_id = trie.get(word.encode('utf8'))
+        assert trie.restore_key(key_id) == word
+
+        key_id = trie.get(word, 'default value')
+        assert trie.restore_key(key_id) == word
+
+    assert trie.get('non_existing_key') is None
+    assert trie.get(b'non_existing_bytes_key') is None
+    assert trie.get('non_existing_key', 'default value') == 'default value'
+    assert trie.get(b'non_existing_bytes_key', 'default value') == 'default value'
+
+
 def test_saveload():
     fd, fname = tempfile.mkstemp()
 
     for word in words:
         assert word in trie2
 
+
 def test_dumps_loads():
     words = get_random_words(1000)
     trie = marisa_trie.Trie(words)
         assert word in trie2
         assert trie2.key_id(word) == trie.key_id(word)
 
+
 def test_pickling():
     words = get_random_words(1000)
     trie = marisa_trie.Trie(words)
         prefix = key[:5]
         assert trie.keys(prefix) == list(trie.iterkeys(prefix))
 
+
+def test_items():
+    keys = ['foo', 'f', 'foobar', 'bar']
+    trie = marisa_trie.Trie(keys)
+    items = trie.items()
+    assert set(items) == set(zip(keys, (trie[k] for k in keys)))
+
+
+def test_items_prefix():
+    keys = ['foo', 'f', 'foobar', 'bar']
+    trie = marisa_trie.Trie(keys)
+    assert set(trie.items('fo')) == set([
+        ('foo', trie['foo']),
+        ('foobar', trie['foobar']),
+    ])
+
+
+def test_iteritems():
+    keys = get_random_words(1000)
+    trie = marisa_trie.Trie(keys)
+    assert trie.items() == list(trie.iteritems())
+
+    for key in keys:
+        prefix = key[:5]
+        assert trie.items(prefix) == list(trie.iteritems(prefix))
+
+
 def test_has_keys_with_prefix():
     empty_trie = marisa_trie.Trie()
     assert empty_trie.has_keys_with_prefix('') == False