1. Killian De Smedt
  2. py3_protobuf

Commits

Killian De Smedt  committed 4efbbab

Made packed lists work and added extensive test for it

  • Participants
  • Parent commits 31af907
  • Branches default

Comments (0)

Files changed (3)

File protobuf/lexer.py

View file
 
 def _decode_attr(tokens, repeated=False):
     typ = next(tokens)[1]
-    typ = typ if typ  in _type_to_wiretype.keys() else "compose"
     name = next(tokens)[1]
-    wiretype = _type_to_wiretype[typ]
-    tag = encode_varint(int(next(tokens)[1]) << 3 | wiretype)
+    tag = int(next(tokens)[1])
     packed = False
     default = None
     while True:
         if extract == ';': break
         elif extract == 'default': default = next(tokens)[1]  # TODO: translate this to value of type...
         elif extract == 'packed' and next(tokens)[1] == 'true': packed = True 
-    return tag, Field(_type_encoder[typ], _type_decoder[typ], name, packed, repeated, default), tokens
+    typ = typ if typ  in _type_to_wiretype.keys() else "compose"
+    dectyp = typ if not packed else ('packed_' + typ)
+    wiretype = _type_to_wiretype[typ] if not packed else _type_to_wiretype["packed"]
+    tag = encode_varint(tag << 3 | wiretype)
+    return tag, Field(_type_encoder[typ], _type_decoder[dectyp], name, packed, repeated, default), tokens
 
 def _decode_message(tokens):
     name = next(tokens)[1]

File protobuf/message.py

View file
 from struct import Struct
 from array import array
 from collections import namedtuple
+from functools import partial
 
 # We could probably speed things up a lot by encoding the tags
 # when creating the messages and always work with those
         value, bytestring = _wiretype_decoder[tag[0] & 0b111](bytestring)
         yield (tag, value)
 
+def fn_varint(fn, bytestring):
+    varint, remaining = decode_varint(bytestring)
+    return fn(varint), remaining
+
 _type_decoder = {'fixed32': lambda x: Struct('<I').unpack(x)[0],
                  'sfixed32': lambda x: Struct('<i').unpack(x)[0],
                  'float': lambda x: Struct('<f').unpack(x)[0],
                  'string': lambda raw: raw.decode("UTF-8"),
                  'bytes': lambda raw: raw,
                  'composed': lambda raw: _decode_msg(raw),
-                 'repeated': lambda raw: raw,
                  'fixed64': lambda x: Struct('<Q').unpack(x)[0],
                  'sfixed64': lambda x: Struct('<q').unpack(x)[0],
                  'double': lambda x: Struct('<d').unpack(x)[0],
                  'sint32': lambda n:-(n >> 1) * ((n & 1) - 1) - (n & 1) * (1 + n >> 1),
                  'sint64': lambda n:-(n >> 1) * ((n & 1) - 1) - (n & 1) * (1 + n >> 1),
                  'bool': lambda raw: bool(raw),
-                 'enum': lambda raw: raw
+                 'enum': lambda raw: raw,
+                 'packed_fixed32': lambda x: (Struct('<I').unpack(x[:4])[0], x[4:]),
+                 'packed_sfixed32': lambda x: (Struct('<i').unpack(x[:4])[0], x[4:]),
+                 'packed_float': lambda x: (Struct('<f').unpack(x[:4])[0], x[4:]),
+                 'packed_fixed64': lambda x: (Struct('<Q').unpack(x[:8])[0], x[8:]),
+                 'packed_sfixed64': lambda x: (Struct('<q').unpack(x[:8])[0], x[8:]),
+                 'packed_double': lambda x: (Struct('<d').unpack(x[:8])[0], x[8:]),
+                 'packed_int32': partial(fn_varint, lambda n: Struct('i').unpack(Struct('I').pack(n))[0]),
+                 'packed_int64': partial(fn_varint, lambda n: Struct('q').unpack(Struct('Q').pack(n))[0]),
+                 'packed_uint32': decode_varint,
+                 'packed_uint64': decode_varint,
+                 'packed_sint32': partial(fn_varint, lambda n:-(n >> 1) * ((n & 1) - 1) - (n & 1) * (1 + n >> 1)),
+                 'packed_sint64': partial(fn_varint, lambda n:-(n >> 1) * ((n & 1) - 1) - (n & 1) * (1 + n >> 1)),
+                 'packed_bool': partial(fn_varint, bool),
+                 'packed_enum': lambda raw: raw,
                  }
 
 _type_encoder = {'fixed32': Struct('<I').pack,
     def __repr__(self):
         result = [str(self.__class__.__name__)]
         result.append(" (")
-        result.append(",".join("{} = {}".format(field.name, getattr(self, field.name, None)) for _, field in self.__class__._fields.items()))
+        result.append(", ".join("{} = {}".format(field.name, getattr(self, field.name, None)) for _, field in self.__class__._fields.items()))
         result.append(")")
         return ''.join(result)
     
     def parse_from_string(self, bytestring):
         for tag, value in _decode_msg(bytestring):
             f = self.__class__._fields.get(tag, Field(lambda x: None, lambda x: None, None, False, False, None))
-            if f.repeated:
+            if f.packed:
+                to_set = []
+                while len(value) > 0:
+                    dec_val, value = f.decode(value)
+                    to_set.append(dec_val)
+                setattr(self, f.name, to_set)
+            elif f.repeated:
                 getattr(self, f.name).append(f.decode(value)) if hasattr(self, f.name) else setattr(self, f.name, [f.decode(value)])
             elif f.name is not None:
                 setattr(self, f.name, f.decode(value))

File protobuf_test/test_lexer.py

View file
             expected += extra_encoding
             self.assertEqual(Test(**kwargs), Test(expected), "deserialize up to {}".format(varname))
             self.assertEqual(Test(**kwargs).serialize_to_string(), expected, "serialize up to {}".format(varname))
+            
+    def test_repeated_native_packed(self):
+        msg = """
+        message Test {
+            repeated double d = 1 [packed = true];
+            repeated float f = 2 [packed = true];
+            repeated int32 i32 = 3 [packed = true];
+            repeated int64 i64 = 4 [packed = true];
+            repeated uint32 u32 = 5 [packed = true];
+            repeated uint64 u64 = 6 [packed = true];
+            repeated sint32 s32 = 7 [packed = true];
+            repeated sint64 s64 = 8 [packed = true]; 
+            repeated fixed32 f32 = 9 [packed = true];
+            repeated fixed64 f64 = 10 [packed = true];
+            repeated sfixed32 sf32 = 11 [packed = true];
+            repeated sfixed64 sf64 = 12 [packed = true];
+            repeated bool b = 13 [packed = true];
+        }
+        """
+        Test = message_from_string(msg).Test
+        kwargs = {}
+        expected = b''
+        # Make use of the repeated things to check for some weird border-cases per type
+        tests = (
+                ('d', [10.0, 2 ** -1023, 0.0, -0.0, 1.0, -1.0, float('Inf'), float('-Inf')], b'\x0a\x40' + b''.join(pack('<d', x) for x in [10.0, 2 ** -1023, 0.0, -0.0, 1.0, -1.0, float('Inf'), float('-Inf')])),
+                ('f', [1234.0, 2 ** -127, 0.0, -0.0, 1.0, -1.0, float('Inf'), float('-Inf')], b'\x12\x20' + b''.join(pack('<f', x) for x in [1234.0, 2 ** -127, 0.0, -0.0, 1.0, -1.0, float('Inf'), float('-Inf')])),
+                ('i32', [-1, 0, 1, -2 ** 31, 2 ** 31 - 1], b'\x1a\x11\xff\xff\xff\xff\x0f\x00\x01\x80\x80\x80\x80\x08\xff\xff\xff\xff\x07'),
+                ('i64', [-1, 0, 1, -2 ** 63, 2 ** 63 - 1], b'\x22\x1f\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01\x00\x01\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01\xff\xff\xff\xff\xff\xff\xff\xff\x7f'),
+                ('u32', [0, 1, 2 ** 32 - 1], b'\x2a\x07\x00\x01\xff\xff\xff\xff\x0f'),
+                ('u64', [0, 1, 2 ** 64 - 1], b'\x32\x0c\x00\x01\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01'),
+                ('s32', [-1, 0, 1, -2 ** 31, 2 ** 31 - 1], b'\x3a\x0d\x01\x00\x02\xff\xff\xff\xff\x0f\xfe\xff\xff\xff\x0f'),
+                ('s64', [-1, 0, 1, -2 ** 63, 2 ** 63 - 1], b'\x42\x17\x01\x00\x02\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01\xfe\xff\xff\xff\xff\xff\xff\xff\xff\x01'),
+                ('f32', [0, 1, 2 ** 32 - 1], b'\x4a\x0c\x00\x00\x00\x00\x01\x00\x00\x00\xff\xff\xff\xff'),
+                ('f64', [0, 1, 2 ** 64 - 1], b'\x52\x18\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff'),
+                ('sf32', [-1, 0, 1, -2 ** 31, 2 ** 31 - 1], b'\x5a\x14\xff\xff\xff\xff\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x80\xff\xff\xff\x7f'),
+                ('sf64', [-1, 0, 1, -2 ** 63, 2 ** 63 - 1], b'\x62\x28\xff\xff\xff\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80\xff\xff\xff\xff\xff\xff\xff\x7f'),
+                ('b', [True, False, True], b'\x6a\x03\x01\x00\x01'),
+                 )
+        for (varname, val, extra_encoding) in tests:
+            kwargs[varname] = val
+            expected += extra_encoding
+            self.assertEqual(Test(**kwargs), Test(expected), "deserialize up to {}".format(varname))
+            self.assertEqual(Test(**kwargs).serialize_to_string(), expected, "serialize up to {}".format(varname))
 
 if __name__ == "__main__":
     # import sys;sys.argv = ['', 'Test.testName']