1. Mikhail Korobov
  2. DAWG-Python

Commits

Mikhail Korobov  committed 3dc0120

BytesDAWG and RecordDAWG are implemented; dev_data is split to small and large

  • Participants
  • Parent commits e63ee74
  • Branches default

Comments (0)

Files changed (25)

File bench/speed.py

View file
         #print(e)
 
 def load_dawg():
-    return dawg_python.DAWG().load(data_path('dawg.dawg'))
+    return dawg_python.DAWG().load(data_path('large', 'dawg.dawg'))
 
 def load_bytes_dawg():
-    return dawg_python.BytesDAWG().load(data_path('bytes_dawg.dawg'))
+    return dawg_python.BytesDAWG().load(data_path('large', 'bytes_dawg.dawg'))
 
 def load_record_dawg():
-    return dawg_python.RecordDAWG().load(data_path('record_dawg.dawg'))
+    return dawg_python.RecordDAWG(str('<H')).load(data_path('large', 'record_dawg.dawg'))
 
 def load_int_dawg():
-    return dawg_python.IntDAWG().load(data_path('int_dawg.dawg'))
+    return dawg_python.IntDAWG().load(data_path('large', 'int_dawg.dawg'))
 
 def benchmark():
     print('\n====== Benchmarks (100k unique unicode words) =======\n')

File bench/utils.py

View file
     'dev_data',
 )
 
-def data_path(filename):
+def data_path(*args):
     """
     Returns a path to dev data
     """
-    return os.path.join(DEV_DATA_PATH, filename)
+    return os.path.join(DEV_DATA_PATH, *args)
 
 def words100k():
     zip_name = data_path('words100k.txt.zip')

File dawg_python/dawgs.py

View file
 # -*- coding: utf-8 -*-
-from __future__ import absolute_import
+from __future__ import absolute_import, unicode_literals
+
+import struct
+from binascii import a2b_base64
+
 from . import wrapper
 
 class DAWG(object):
         self.dct = wrapper.Dictionary.load(path)
         return self
 
+    def _has_value(self, index):
+        return self.dct.has_value(index)
+
     def _similar_keys(self, current_prefix, key, index, replace_chars):
 
         res = []
             word_pos += 1
 
         else:
-            if self.dct.has_value(index):
+            if self._has_value(index):
                 found_key = current_prefix + key[start_pos:]
                 res.insert(0, found_key)
 
         self.completer = wrapper.Completer(self.dct, self.guide)
         return self
 
+
+# The byte 0xff never occurs in valid UTF-8, so it is safe to use
+# as a separator between the utf8-encoded key and its binary payload.
+PAYLOAD_SEPARATOR = b'\xff'
+MAX_VALUE_SIZE = 32768
+
 class BytesDAWG(CompletionDAWG):
-    def __init__(self):
-        raise NotImplementedError
+    """
+    DAWG that is able to transparently store extra binary payload in keys;
+    there may be several payloads for the same key.
+
+    In other words, this class implements read-only DAWG-based
+    {unicode -> list of bytes objects} mapping.
+    """
+
+    def __contains__(self, key):
+        if not isinstance(key, bytes):
+            key = key.encode('utf8')
+        return bool(self._follow_key(key))
+
+#    def b_has_key(self, key):
+#        return bool(self._follow_key(key))
+
+    def __getitem__(self, key):
+        res = self.get(key)
+        if res is None:
+            raise KeyError(key)
+        return res
+
+    def get(self, key, default=None):
+        """
+        Returns a list of payloads (as byte objects) for a given key
+        or ``default`` if the key is not found.
+        """
+        if not isinstance(key, bytes):
+            key = key.encode('utf8')
+
+        return self.b_get_value(key) or default
+
+    def _follow_key(self, b_key):
+        index = self.dct.follow_bytes(b_key, self.dct.root())
+        if not index:
+            return False
+
+        index = self.dct.follow_bytes(PAYLOAD_SEPARATOR, index)
+        if not index:
+            return False
+
+        return index
+
+    def _value_for_index(self, index):
+        res = []
+
+        self.completer.start(index)
+        while self.completer.next():
+            # a2b_base64 doesn't support bytearray in python 2.6
+            # so it is converted (and copied) to bytes
+            b64_data = bytes(self.completer.key)
+            res.append(a2b_base64(b64_data))
+
+        return res
+
+    def b_get_value(self, b_key):
+        index = self._follow_key(b_key)
+        if not index:
+            return []
+        return self._value_for_index(index)
+
+    def keys(self, prefix=""):
+        if not isinstance(prefix, bytes):
+            prefix = prefix.encode('utf8')
+        res = []
+
+        index = self.dct.root()
+
+        if prefix:
+            index = self.dct.follow_bytes(prefix, index)
+            if not index:
+                return res
+
+        self.completer.start(index, prefix)
+        while self.completer.next():
+            payload_idx = self.completer.key.index(PAYLOAD_SEPARATOR)
+            u_key = self.completer.key[:payload_idx].decode('utf8')
+            res.append(u_key)
+        return res
+
+    def items(self, prefix=""):
+        if not isinstance(prefix, bytes):
+            prefix = prefix.encode('utf8')
+        res = []
+
+        index = self.dct.root()
+        if prefix:
+            index = self.dct.follow_bytes(prefix, index)
+            if not index:
+                return res
+
+        self.completer.start(index, prefix)
+        while self.completer.next():
+            key, value = self.completer.key.split(PAYLOAD_SEPARATOR)
+            res.append(
+                (key.decode('utf8'), a2b_base64(bytes(value))) # python 2.6 fix
+            )
+
+        return res
+
+
+    def _has_value(self, index):
+        return self.dct.follow_bytes(PAYLOAD_SEPARATOR, index)
+
+    def _similar_items(self, current_prefix, key, index, replace_chars):
+
+        res = []
+        start_pos = len(current_prefix)
+        end_pos = len(key)
+        word_pos = start_pos
+
+        while word_pos < end_pos:
+            b_step = key[word_pos].encode('utf8')
+
+            if b_step in replace_chars:
+                next_index = index
+                b_replace_char, u_replace_char = replace_chars[b_step]
+
+                next_index = self.dct.follow_bytes(b_replace_char, next_index)
+                if next_index:
+                    prefix = current_prefix + key[start_pos:word_pos] + u_replace_char
+                    extra_items = self._similar_items(prefix, key, next_index, replace_chars)
+                    res += extra_items
+
+            index = self.dct.follow_bytes(b_step, index)
+            if not index:
+                break
+            word_pos += 1
+
+        else:
+            index = self.dct.follow_bytes(PAYLOAD_SEPARATOR, index)
+            if index:
+                found_key = current_prefix + key[start_pos:]
+                value = self._value_for_index(index)
+                res.insert(0, (found_key, value))
+
+        return res
+
+    def similar_items(self, key, replaces):
+        """
+        Returns a list of (key, value) tuples for all variants of ``key``
+        in this DAWG according to ``replaces``.
+
+        ``replaces`` is an object obtained from
+        ``DAWG.compile_replaces(mapping)`` where mapping is a dict
+        that maps single-char unicode strings to another single-char
+        unicode strings.
+        """
+        return self._similar_items("", key, self.dct.root(), replaces)
+
+
+    def _similar_item_values(self, start_pos, key, index, replace_chars):
+        res = []
+        end_pos = len(key)
+        word_pos = start_pos
+
+        while word_pos < end_pos:
+            b_step = key[word_pos].encode('utf8')
+
+            if b_step in replace_chars:
+                next_index = index
+                b_replace_char, u_replace_char = replace_chars[b_step]
+
+                next_index = self.dct.follow_bytes(b_replace_char, next_index)
+                if next_index:
+                    extra_items = self._similar_item_values(word_pos+1, key, next_index, replace_chars)
+                    res += extra_items
+
+            index = self.dct.follow_bytes(b_step, index)
+            if not index:
+                break
+            word_pos += 1
+
+        else:
+            index = self.dct.follow_bytes(PAYLOAD_SEPARATOR, index)
+            if index:
+                value = self._value_for_index(index)
+                res.insert(0, value)
+
+        return res
+
+    def similar_item_values(self, key, replaces):
+        """
+        Returns a list of values for all variants of the ``key``
+        in this DAWG according to ``replaces``.
+
+        ``replaces`` is an object obtained from
+        ``DAWG.compile_replaces(mapping)`` where mapping is a dict
+        that maps single-char unicode strings to another single-char
+        unicode strings.
+        """
+        return self._similar_item_values(0, key, self.dct.root(), replaces)
+
 
 class RecordDAWG(BytesDAWG):
-    def __init__(self):
-        raise NotImplementedError
+    def __init__(self, fmt):
+        super(RecordDAWG, self).__init__()
+        self._struct = struct.Struct(str(fmt))
+        self.fmt = fmt
+
+    def _value_for_index(self, index):
+        value = super(RecordDAWG, self)._value_for_index(index)
+        return [self._struct.unpack(val) for val in value]
+
+    def items(self, prefix=""):
+        res = super(RecordDAWG, self).items(prefix)
+        return [(key, self._struct.unpack(val)) for (key, val) in res]
+

File dawg_python/wrapper.py

View file
     def value(self):
         return self._dic.value(self._last_index)
 
-    def start(self, index, prefix=""):
+    def start(self, index, prefix=b""):
         self.key = bytearray(prefix)
         self._index_stack = [index]
         self._last_index = self._dic.root()

File dev_data/bytes_dawg.dawg

Binary file removed.

File dev_data/completion.dawg

Binary file removed.

File dev_data/dawg.dawg

Binary file removed.

File dev_data/int_dawg.dawg

Binary file removed.

File dev_data/large/bytes_dawg.dawg

Binary file added.

File dev_data/large/dawg.dawg

Binary file added.

File dev_data/large/int_dawg.dawg

Binary file added.

File dev_data/large/record_dawg.dawg

Binary file added.

File dev_data/prediction-record.dawg

Binary file removed.

File dev_data/prediction.dawg

Binary file removed.

File dev_data/record_dawg.dawg

Binary file removed.

File dev_data/small/bytes.dawg

Binary file added.

File dev_data/small/completion.dawg

Binary file added.

File dev_data/small/prediction-record.dawg

Binary file added.

File dev_data/small/prediction.dawg

Binary file added.

File dev_data/small/record.dawg

Binary file added.

File tests/test_dawg.py

View file
     keys = ['f', 'bar', 'foo', 'foobar']
 
     def dawg(self):
-        return dawg_python.CompletionDAWG().load(data_path('completion.dawg'))
+        return dawg_python.CompletionDAWG().load(data_path('small', 'completion.dawg'))
 
     def test_contains(self):
         d = self.dawg()

File tests/test_fuzzy.py

View file
 # -*- coding: utf-8 -*-
-from __future__ import absolute_import
+from __future__ import absolute_import, unicode_literals
 
 import dawg_python
 
 from .utils import words100k, data_path
 
 words = words100k()
-dawg = dawg_python.Dictionary.load(data_path('int_dawg.dawg'))
+dawg = dawg_python.Dictionary.load(data_path('large', 'int_dawg.dawg'))
 
 class TestDictionary(object):
 

File tests/test_payload_dawg.py

View file
+# -*- coding: utf-8 -*-
+from __future__ import absolute_import, unicode_literals
+
+import pytest
+import dawg_python
+from .utils import data_path
+
+class TestBytesDAWG(object):
+
+    DATA = (
+        ('foo', b'data1'),
+        ('bar', b'data2'),
+        ('foo', b'data3'),
+        ('foobar', b'data4')
+    )
+
+    def dawg(self):
+        return dawg_python.BytesDAWG().load(data_path("small", "bytes.dawg"))
+
+    def test_contains(self):
+        d = self.dawg()
+        for key, val in self.DATA:
+            assert key in d
+
+        assert 'food' not in d
+        assert 'x' not in d
+        assert 'fo' not in d
+
+
+    def test_getitem(self):
+        d = self.dawg()
+
+        assert d['foo'] == [b'data1', b'data3']
+        assert d['bar'] == [b'data2']
+        assert d['foobar'] == [b'data4']
+
+
+    def test_getitem_missing(self):
+        d = self.dawg()
+
+        with pytest.raises(KeyError):
+            d['x']
+
+        with pytest.raises(KeyError):
+            d['food']
+
+        with pytest.raises(KeyError):
+            d['foobarz']
+
+        with pytest.raises(KeyError):
+            d['f']
+
+    def test_keys(self):
+        d = self.dawg()
+        assert d.keys() == ['bar', 'foobar', 'foo', 'foo'] # order?
+
+    def test_key_completion(self):
+        d = self.dawg()
+        assert d.keys('fo') == ['foobar', 'foo', 'foo'] # order?
+
+    def test_items(self):
+        d = self.dawg()
+        assert sorted(d.items()) == sorted(self.DATA)
+
+    def test_items_completion(self):
+        d = self.dawg()
+        assert d.items('foob') == [('foobar', b'data4')]
+
+
+class TestRecordDAWG(object):
+
+    STRUCTURED_DATA = (  # payload is (length, vowels count, index) tuple
+        ('foo',     (3, 2, 0)),
+        ('bar',     (3, 1, 0)),
+        ('foo',     (3, 2, 1)),
+        ('foobar',  (6, 3, 0))
+    )
+
+    def dawg(self):
+        path = data_path("small", "record.dawg")
+        return dawg_python.RecordDAWG("=3H").load(path)
+
+    def test_getitem(self):
+        d = self.dawg()
+        assert d['foo'] == [(3, 2, 0), (3, 2, 1)]
+        assert d['bar'] == [(3, 1, 0)]
+        assert d['foobar'] == [(6, 3, 0)]
+
+    def test_getitem_missing(self):
+        d = self.dawg()
+
+        with pytest.raises(KeyError):
+            d['x']
+
+        with pytest.raises(KeyError):
+            d['food']
+
+        with pytest.raises(KeyError):
+            d['foobarz']
+
+        with pytest.raises(KeyError):
+            d['f']
+
+    def test_record_items(self):
+        d = self.dawg()
+        assert sorted(d.items()) == sorted(self.STRUCTURED_DATA)
+
+    def test_record_keys(self):
+        d = self.dawg()
+        assert sorted(d.keys()) == ['bar', 'foo', 'foo', 'foobar',]
+
+    def test_record_keys_prefix(self):
+        d = self.dawg()
+        assert sorted(d.keys('fo')) == ['foo', 'foo', 'foobar']
+        assert d.keys('bar') == ['bar']
+        assert d.keys('barz') == []

File tests/test_prediction.py

View file
 from .utils import data_path
 
 class TestPrediction(object):
-    DATA = ['ЁЖИК', 'ЁЖИКЕ', 'ЁЖ', 'ДЕРЕВНЯ', 'ДЕРЁВНЯ', 'ЕМ', 'ОЗЕРА', 'ОЗЁРА', 'ОЗЕРО']
-    LENGTH_DATA = list(zip(DATA, ((len(w),) for w in DATA)))
 
     REPLACES = dawg_python.DAWG.compile_replaces({'Е': 'Ё'})
 
+    # DATA = ['ЁЖИК', 'ЁЖИКЕ', 'ЁЖ', 'ДЕРЕВНЯ', 'ДЕРЁВНЯ', 'ЕМ', 'ОЗЕРА', 'ОЗЁРА', 'ОЗЕРО']
     SUITE = [
         ('УЖ', []),
         ('ЕМ', ['ЕМ']),
         for it in SUITE
     ]
 
+    def record_dawg(self):
+        path = data_path("small", "prediction-record.dawg")
+        return dawg_python.RecordDAWG(str("=H")).load(path)
+
+
 
     @pytest.mark.parametrize(("word", "prediction"), SUITE)
     def test_dawg_prediction(self, word, prediction):
-        d = dawg_python.DAWG().load(data_path('prediction.dawg'))
+        d = dawg_python.DAWG().load(data_path("small", "prediction.dawg"))
         assert d.similar_keys(word, self.REPLACES) == prediction
 
-#    @pytest.mark.parametrize(("word", "prediction"), SUITE)
-#    def test_record_dawg_prediction(self, word, prediction):
-#        d = dawg.RecordDAWG(str("=H"), self.LENGTH_DATA)
-#        assert d.similar_keys(word, self.REPLACES) == prediction
-#
-#    @pytest.mark.parametrize(("word", "prediction"), SUITE_ITEMS)
-#    def test_record_dawg_items(self, word, prediction):
-#        d = dawg.RecordDAWG(str("=H"), self.LENGTH_DATA)
-#        assert d.similar_items(word, self.REPLACES) == prediction
-#
-#    @pytest.mark.parametrize(("word", "prediction"), SUITE_VALUES)
-#    def test_record_dawg_items_values(self, word, prediction):
-#        d = dawg.RecordDAWG(str("=H"), self.LENGTH_DATA)
-#        assert d.similar_item_values(word, self.REPLACES) == prediction
+    @pytest.mark.parametrize(("word", "prediction"), SUITE)
+    def test_record_dawg_prediction(self, word, prediction):
+        d = self.record_dawg()
+        assert d.similar_keys(word, self.REPLACES) == prediction
+
+    @pytest.mark.parametrize(("word", "prediction"), SUITE_ITEMS)
+    def test_record_dawg_items(self, word, prediction):
+        d = self.record_dawg()
+        assert d.similar_items(word, self.REPLACES) == prediction
+
+    @pytest.mark.parametrize(("word", "prediction"), SUITE_VALUES)
+    def test_record_dawg_items_values(self, word, prediction):
+        d = self.record_dawg()
+        assert d.similar_item_values(word, self.REPLACES) == prediction

File tests/utils.py

View file
     'dev_data',
 )
 
-def data_path(filename):
+def data_path(*args):
     """
     Returns a path to dev data
     """
-    return os.path.join(DEV_DATA_PATH, filename)
+    return os.path.join(DEV_DATA_PATH, *args)
 
 def words100k():
     zip_name = data_path('words100k.txt.zip')