1. Mikhail Korobov
  2. DAWG

Commits

Mikhail Korobov  committed aa8cfb6

Don't use a shared Completer instance because this is not safe. Unfortunately, this makes the library slower.

  • Participants
  • Parent commits ecc7bd1
  • Branches default

Comments (0)

Files changed (1)

File src/dawg.pyx

View file
 # cython: profile=False
 from __future__ import unicode_literals
 from libcpp.string cimport string
-
+from libcpp.vector cimport vector
 from iostream cimport stringstream, istream, ostream, ifstream
 cimport iostream
 
         )
 
 
+cdef void init_completer(Completer& completer, Dictionary& dic, Guide& guide):
+    completer.set_dic(dic)
+    completer.set_guide(guide)
+
+
 cdef class CompletionDAWG(DAWG):
     """
     DAWG with key completion support.
     """
     cdef Guide guide
-    cdef Completer* completer
 
     def __init__(self, arg=None, input_is_sorted=False):
         super(CompletionDAWG, self).__init__(arg, input_is_sorted)
         if not _guide_builder.Build(self.dawg, self.dct, &self.guide):
             raise Error("Error building completion information")
-        if not self.completer:
-            self.completer = new Completer(self.dct, self.guide)
 
     def __dealloc__(self):
         self.guide.Clear()
-        if self.completer:
-            del self.completer
 
     cpdef list keys(self, unicode prefix=""):
         cdef bytes b_prefix = prefix.encode('utf8')
         if not self.dct.Follow(b_prefix, &index):
             return res
 
-        self.completer.Start(index, b_prefix)
+        cdef Completer completer
+        init_completer(completer, self.dct, self.guide)
+        completer.Start(index, b_prefix)
 
-        while self.completer.Next():
-            key = (<char*>self.completer.key()).decode('utf8')
+        while completer.Next():
+            key = (<char*>completer.key()).decode('utf8')
             res.append(key)
 
         return res
         if not self.dct.Follow(b_prefix, &index):
             return
 
-        self.completer.Start(index, b_prefix)
-        while self.completer.Next():
-            key = (<char*>self.completer.key()).decode('utf8')
+        cdef Completer completer
+        init_completer(completer, self.dct, self.guide)
+        completer.Start(index, b_prefix)
+
+        while completer.Next():
+            key = (<char*>completer.key()).decode('utf8')
             yield key
 
 
             self.dct.Clear()
             raise IOError("Invalid data format: can't load _dawg.Guide")
 
-        if self.completer:
-            del self.completer
-        self.completer = new Completer(self.dct, self.guide)
-
         return self
 
 
                 self.dct.Clear()
                 raise IOError("Invalid data format: can't load _dawg.Guide")
 
-            if self.completer:
-                del self.completer
-            self.completer = new Completer(self.dct, self.guide)
-
         finally:
             stream.close()
 
         cdef BaseType index, prev_index, completer_index
         cdef char* key
 
-        self.completer.Start(self.dct.root())
-        while self.completer.Next():
-            key = <char*>self.completer.key()
+        cdef Completer completer
+        init_completer(completer, self.dct, self.guide)
+        completer.Start(self.dct.root())
+
+        while completer.Next():
+            key = <char*>completer.key()
 
             index = self.dct.root()
 
-            for i in range(self.completer.length()):
+            for i in range(completer.length()):
                 prev_index = index
                 self.dct.Follow(&(key[i]), 1, &index)
                 transitions.add(
 # The following symbol is not allowed in utf8 so it is safe to use
 # as a separator between utf8-encoded string and binary payload.
 # It has drawbacks however: sorting of utf8-encoded keys changes:
-# (foo' becomes greater than 'foox' because strings are compared as
+# ('foo' becomes greater than 'foox' because strings are compared as
 # 'foo<sep>' and 'foox<sep>' and ord(<sep>)==255 is greater than
 # ord(<any other character>).
 # DEF PAYLOAD_SEPARATOR = b'\xff'
 
     cdef bytes _b_payload_separator
     cdef CharType _c_payload_separator
+    cdef Completer* _completer
 
     def __init__(self, arg=None, input_is_sorted=False, bytes payload_separator=PAYLOAD_SEPARATOR):
         """
         self._c_payload_separator = <unsigned int>ord(payload_separator)
 
         keys = (self._raw_key(d[0], d[1]) for d in arg)
-
         super(BytesDAWG, self).__init__(keys, input_is_sorted)
 
+        self._update_completer()
+
+    def __dealloc__(self):
+        if self._completer:
+            del self._completer
 
     cpdef bytes _raw_key(self, unicode key, bytes payload):
         cdef bytes b_key = key.encode('utf8')
         cdef bytes encoded_payload = b2a_base64(payload)
         return b_key + self._b_payload_separator + encoded_payload
 
+    cdef _update_completer(self):
+        if self._completer:
+            del self._completer
+        self._completer = new Completer(self.dct, self.guide)
+
+    def load(self, path):
+        res = super(BytesDAWG, self).load(path)
+        self._update_completer()
+        return res
+
+    cpdef frombytes(self, bytes data):
+        res = super(BytesDAWG, self).frombytes(data)
+        self._update_completer()
+        return res
+
     cpdef bint b_has_key(self, bytes key) except -1:
         cdef BaseType index
         return self._follow_key(key, &index)
         return self.b_get_value(b_key)
 
     cdef list _value_for_index(self, BaseType index):
-        cdef list res = []
-        cdef int _len
-        cdef b64_decode.decoder _b64_decoder
-        cdef char[MAX_VALUE_SIZE] _b64_decoder_storage
-        cdef bytes payload
 
-        self.completer.Start(index)
-        while self.completer.Next():
-            _b64_decoder.init()
-            _len = _b64_decoder.decode(
-                self.completer.key(),
-                self.completer.length(),
-                _b64_decoder_storage
+        # We want to use a shared Completer instance because allocating
+        # a Completer makes this function (and thus __getitem__) 2x slower.
+        # This may not be thread-safe; the GIL helps us, but we should be careful
+        # not to accidentally switch to another thread by interacting
+        # with the Python interpreter in any way (a switch happens
+        # between bytecode instructions).
+
+        cdef int key_len
+        cdef b64_decode.decoder b64_decoder
+        cdef char[MAX_VALUE_SIZE] b64_decoder_storage
+        cdef vector[string] results
+
+        self._completer.Start(index)
+
+        while self._completer.Next():
+            b64_decoder.init()
+            key_len = b64_decoder.decode(
+                self._completer.key(),
+                self._completer.length(),
+                b64_decoder_storage
             )
-            payload = _b64_decoder_storage[:_len]
-            res.append(payload)
+            results.push_back(string(b64_decoder_storage, key_len))
 
-        return res
+        return results
 
     cpdef list b_get_value(self, bytes key):
         cdef BaseType index
         cdef b64_decode.decoder _b64_decoder
         cdef char[MAX_VALUE_SIZE] _b64_decoder_storage
 
-        self.completer.Start(index, b_prefix)
-        while self.completer.Next():
-            raw_key = <char*>self.completer.key()
+        cdef Completer completer
+        init_completer(completer, self.dct, self.guide)
+        completer.Start(index, b_prefix)
 
-            for i in range(0, self.completer.length()):
+        while completer.Next():
+            raw_key = <char*>completer.key()
+
+            for i in range(0, completer.length()):
                 if raw_key[i] == self._c_payload_separator:
                     break
 
             raw_value = &(raw_key[i])
-            raw_value_len = self.completer.length() - i
+            raw_value_len = completer.length() - i
 
             _b64_decoder.init()
             _len = _b64_decoder.decode(raw_value, raw_value_len, _b64_decoder_storage)
         cdef b64_decode.decoder _b64_decoder
         cdef char[MAX_VALUE_SIZE] _b64_decoder_storage
 
-        self.completer.Start(index, b_prefix)
-        while self.completer.Next():
-            raw_key = <char*>self.completer.key()
+        cdef Completer completer
+        init_completer(completer, self.dct, self.guide)
+        completer.Start(index, b_prefix)
 
-            for i in range(0, self.completer.length()):
+        while completer.Next():
+            raw_key = <char*>completer.key()
+
+            for i in range(0, completer.length()):
                 if raw_key[i] == self._c_payload_separator:
                     break
 
             raw_value = &(raw_key[i])
-            raw_value_len = self.completer.length() - i
+            raw_value_len = completer.length() - i
 
             _b64_decoder.init()
             _len = _b64_decoder.decode(raw_value, raw_value_len, _b64_decoder_storage)
         if not self.dct.Follow(b_prefix, &index):
             return res
 
-        self.completer.Start(index, b_prefix)
-        while self.completer.Next():
-            raw_key = <char*>self.completer.key()
+        cdef Completer completer
+        init_completer(completer, self.dct, self.guide)
+        completer.Start(index, b_prefix)
 
-            for i in range(0, self.completer.length()):
+        while completer.Next():
+            raw_key = <char*>completer.key()
+
+            for i in range(0, completer.length()):
                 if raw_key[i] == self._c_payload_separator:
                     break
 
         if not self.dct.Follow(b_prefix, &index):
             return
 
-        self.completer.Start(index, b_prefix)
-        while self.completer.Next():
-            raw_key = <char*>self.completer.key()
+        cdef Completer completer
+        init_completer(completer, self.dct, self.guide)
+        completer.Start(index, b_prefix)
 
-            for i in range(0, self.completer.length()):
+        while completer.Next():
+            raw_key = <char*>completer.key()
+
+            for i in range(0, completer.length()):
                 if raw_key[i] == self._c_payload_separator:
                     break