Commits

Mikhail Korobov committed 129567f

basic trie wrappers

  • Participants
  • Parent commits 598f62c

Comments (0)

Files changed (5)

File src/chat_trie.pxd

 cdef extern from "../hat-trie/src/hat-trie.h":
 
-    cdef int value_t
-    cdef int size_t
+    ctypedef int value_t
+    ctypedef int size_t
 
     ctypedef struct hattrie_t:
         pass
 
-    hattrie_t* hattrie_create (void)             # Create an empty hat-trie.
+    hattrie_t* hattrie_create ()                 # Create an empty hat-trie.
     void       hattrie_free   (hattrie_t*)       # Free all memory used by a trie.
-    hattrie_t* hattrie_dup    (const hattrie_t*) # Duplicate an existing trie.
+    hattrie_t* hattrie_dup    (hattrie_t*)       # Duplicate an existing trie.
     void       hattrie_clear  (hattrie_t*)       # Remove all entries.
 
 
     ctypedef struct hattrie_iter_t:
         pass
 
-    hattrie_iter_t* hattrie_iter_begin     (const hattrie_t*)
+    hattrie_iter_t* hattrie_iter_begin     (hattrie_t*)
     void            hattrie_iter_next      (hattrie_iter_t*)
-    bool            hattrie_iter_finished  (hattrie_iter_t*)
+    bint            hattrie_iter_finished  (hattrie_iter_t*)
     void            hattrie_iter_free      (hattrie_iter_t*)
     char*           hattrie_iter_key       (hattrie_iter_t*, size_t* len)
     value_t*        hattrie_iter_val       (hattrie_iter_t*)

File src/hat_trie.pyx

+cimport chat_trie
+
+cdef class BaseTrie:
+
+    cdef chat_trie.hattrie_t* _trie
+
+    def __cinit__(self):
+        self._trie = chat_trie.hattrie_create()
+
+    def __dealloc__(self):
+        if self._trie:
+            chat_trie.hattrie_free(self._trie)
+
+    def __getitem__(self, bytes key):
+        return self._getitem(key)
+
+    cdef int _getitem(self, bytes key) except -1:
+        cdef char* c_key = key
+        cdef chat_trie.value_t* value_ptr = chat_trie.hattrie_tryget(self._trie, c_key, len(c_key))
+        if value_ptr == NULL:
+            raise KeyError(key)
+        return value_ptr[0]
+
+    def __setitem__(self, bytes key, int value):
+        self._setitem(key, value)
+
+    cdef void _setitem(self, bytes key, chat_trie.value_t value):
+        chat_trie.hattrie_get(self._trie, key, len(key))[0] = value
+
+    def __contains__(self, bytes key):
+        return self._contains(key)
+
+    cdef bint _contains(self, bytes key):
+        cdef chat_trie.value_t* value_ptr = chat_trie.hattrie_tryget(self._trie, key, len(key))
+        return value_ptr != NULL
+
+
+cdef class Trie(BaseTrie):
+    cdef unicode encoding
+
+    def __init__(self, encoding='latin1'):
+        self.encoding = encoding
+
+    def __getitem__(self, unicode key):
+        return self._getitem(key.encode(self.encoding))
+
+    def __contains__(self, unicode key):
+        return self._contains(key.encode(self.encoding))
+
+    def __setitem__(self, unicode key, int value):
+        self._setitem(key.encode(self.encoding), value)

File tests/__init__.py

+# -*- coding: utf-8 -*-
+from __future__ import absolute_import

File tests/test_base_trie.py

+# -*- coding: utf-8 -*-
+from __future__ import absolute_import, unicode_literals
+import string
+import random
+
+import pytest
+import hat_trie
+
+def test_get_set():
+    trie = hat_trie.BaseTrie()
+    trie[b'foo'] = 5
+    trie[b'bar'] = 10
+    assert trie[b'foo'] == 5
+    assert trie[b'bar'] == 10
+
+    with pytest.raises(KeyError):
+        trie[b'f']
+
+    with pytest.raises(KeyError):
+        trie[b'foob']
+
+    with pytest.raises(KeyError):
+        trie[b'x']
+
+    non_ascii_key = 'вася'.encode('cp1251')
+    trie[non_ascii_key] = 20
+    assert trie[non_ascii_key] == 20
+
+def test_contains():
+    trie = hat_trie.BaseTrie()
+    assert b'foo' not in trie
+    trie[b'foo'] = 5
+    assert b'foo' in trie
+    assert b'f' not in trie
+
+
+@pytest.mark.parametrize(("encoding",), [['cp1251'], ['utf8']])
+def test_get_set_fuzzy(encoding):
+    russian = 'абвгдеёжзиклмнопрстуфхцчъыьэюя'
+    alphabet = russian.upper() + string.ascii_lowercase
+    words = list(set([
+        "".join([random.choice(alphabet) for x in range(random.randint(2,10))])
+        for y in range(1000)
+    ]))
+
+    words = [w.encode(encoding) for w in words]
+
+    trie = hat_trie.BaseTrie()
+
+    enumerated_words = list(enumerate(words))
+
+    for index, word in enumerated_words:
+        trie[word] = index
+
+    random.shuffle(enumerated_words)
+    for index, word in enumerated_words:
+        assert word in trie, word
+        assert trie[word] == index, (word, index)
+

File tests/test_trie.py

+# -*- coding: utf-8 -*-
+from __future__ import absolute_import, unicode_literals
+import string
+import random
+
+import pytest
+import hat_trie
+
+def test_get_set():
+    trie = hat_trie.Trie('cp1251')
+    trie['foo'] = 5
+    trie['bar'] = 10
+    assert trie['foo'] == 5
+    assert trie['bar'] == 10
+
+    with pytest.raises(KeyError):
+        trie['f']
+
+    with pytest.raises(KeyError):
+        trie['foob']
+
+    with pytest.raises(KeyError):
+        trie['x']
+
+    non_ascii_key = 'вася'
+    trie[non_ascii_key] = 20
+    assert trie[non_ascii_key] == 20
+
+def test_contains():
+    trie = hat_trie.Trie('1251')
+    assert 'foo' not in trie
+    trie['foo'] = 5
+    assert 'foo' in trie
+    assert 'f' not in trie
+
+
+@pytest.mark.parametrize(("encoding",), [['cp1251'], ['utf8']])
+def test_get_set_fuzzy(encoding):
+    russian = 'абвгдеёжзиклмнопрстуфхцчъыьэюя'
+    alphabet = russian.upper() + string.ascii_lowercase
+    words = list(set([
+        "".join([random.choice(alphabet) for x in range(random.randint(2,10))])
+        for y in range(1000)
+    ]))
+
+    trie = hat_trie.Trie(encoding)
+
+    enumerated_words = list(enumerate(words))
+
+    for index, word in enumerated_words:
+        trie[word] = index
+
+    random.shuffle(enumerated_words)
+    for index, word in enumerated_words:
+        assert word in trie, word
+        assert trie[word] == index, (word, index)
+