Commits

mattip committed d5e489e

test, fix for subtype pickle numpy compatability, including quirks

  • Participants
  • Parent commits 92ff43b

Comments (0)

Files changed (2)

File pypy/module/micronumpy/interp_numarray.py

         multiarray = numpypy.get("multiarray")
         assert isinstance(multiarray, MixedModule)
         reconstruct = multiarray.get("_reconstruct")
-
-        parameters = space.newtuple([space.gettypefor(W_NDimArray), space.newtuple([space.wrap(0)]), space.wrap("b")])
+        parameters = space.newtuple([self.getclass(space),
+                        space.newtuple([space.wrap(0)]), space.wrap("b")])
 
         builder = StringBuilder()
         if isinstance(self.implementation, SliceArray):
         return space.newtuple([reconstruct, parameters, state])
 
     def descr_setstate(self, space, w_state):
-        from rpython.rtyper.lltypesystem import rffi
-
-        shape = space.getitem(w_state, space.wrap(1))
-        dtype = space.getitem(w_state, space.wrap(2))
-        assert isinstance(dtype, interp_dtype.W_Dtype)
-        isfortran = space.getitem(w_state, space.wrap(3))
-        storage = space.getitem(w_state, space.wrap(4))
-
+        lens = space.len_w(w_state)
+        # numpy compatability, see multiarray/methods.c
+        if lens == 5:
+            base_index = 1
+        elif lens == 4:
+            base_index = 0
+        else:
+            raise OperationError(space.w_ValueError, space.wrap(
+                 "__setstate__ called with len(args[1])==%d, not 5 or 4" % lens))
+        shape = space.getitem(w_state, space.wrap(base_index))
+        dtype = space.getitem(w_state, space.wrap(base_index+1))
+        isfortran = space.getitem(w_state, space.wrap(base_index+2))
+        storage = space.getitem(w_state, space.wrap(base_index+3))
+        if not isinstance(dtype, interp_dtype.W_Dtype):
+            raise OperationError(space.w_ValueError, space.wrap(
+                 "__setstate__(self, (shape, dtype, .. called with improper dtype '%r'" % dtype))
         self.implementation = W_NDimArray.from_shape_and_storage(space,
                 [space.int_w(i) for i in space.listview(shape)],
                 rffi.str2charp(space.str_w(storage), track_allocation=False),

File pypy/module/micronumpy/test/test_subtype.py

 
 
 class AppTestSupport(BaseNumpyAppTest):
+    spaceconfig = dict(usemodules=["micronumpy", "struct", "binascii"])
     def setup_class(cls):
         BaseNumpyAppTest.setup_class.im_func(cls)
         cls.w_NoNew = cls.space.appexec([], '''():
         a = matrix([[1., 2.]])
         b = N.array([a])
 
+    def test_setstate_no_version(self):
+        # Some subclasses of ndarray, like MaskedArray, do not use
+        # version in __setstare__
+        from numpy import ndarray, array
+        from pickle import loads, dumps
+        import sys, new
+        class D(ndarray):
+            ''' A subtype with a constructor that accepts a list of
+                data values, where ndarray accepts a shape
+            '''
+            def __new__(subtype, data, dtype=None, copy=True):
+                arr = array(data, dtype=dtype, copy=copy)
+                shape = arr.shape
+                ret = ndarray.__new__(subtype, shape, arr.dtype,
+                                        buffer=arr,
+                                        order=True)
+                return ret
+            def __setstate__(self, state):
+                (version, shp, typ, isf, raw) = state
+                ndarray.__setstate__(self, (shp, typ, isf, raw))
 
+        D.__module__ = 'mod'
+        mod = new.module('mod')
+        mod.D = D
+        sys.modules['mod'] = mod
+        a = D([1., 2.])
+        s = dumps(a)
+        #Taken from numpy version 1.8
+        s_from_numpy = '''ignore this line
+            _reconstruct
+            p0
+            (cmod
+            D
+            p1
+            (I0
+            tp2
+            S'b'
+            p3
+            tp4
+            Rp5
+            (I1
+            (I2
+            tp6
+            cnumpy
+            dtype
+            p7
+            (S'f8'
+            p8
+            I0
+            I1
+            tp9
+            Rp10
+            (I3
+            S'<'
+            p11
+            NNNI-1
+            I-1
+            I0
+            tp12
+            bI00
+            S'\x00\x00\x00\x00\x00\x00\xf0?\x00\x00\x00\x00\x00\x00\x00@'
+            p13
+            tp14
+            b.'''.replace('            ','')
+        for ss,sn in zip(s.split('\n')[1:],s_from_numpy.split('\n')[1:]):
+            if len(ss)>10:
+                # ignore binary data, it will be checked later
+                continue
+            assert ss == sn
+        b = loads(s)
+        assert (a == b).all()
+        assert isinstance(b, D)