Brian Kearns avatar Brian Kearns committed e233c50

fix some ndarray.astype cases

Comments (0)

Files changed (3)

pypy/module/micronumpy/arrayimpl/concrete.py

                                                     self.order)
         impl = ConcreteArray(self.get_shape(), dtype, self.order,
                              strides, backstrides)
-        if self.dtype.is_str_or_unicode() and not dtype.is_str_or_unicode():
-            raise OperationError(space.w_NotImplementedError, space.wrap(
-                "astype(%s) not implemented yet" % self.dtype))
-        else:
-            loop.setslice(space, impl.get_shape(), impl, self)
+        loop.setslice(space, impl.get_shape(), impl, self)
         return impl
 
 

pypy/module/micronumpy/interp_numarray.py

         return contig.argsort(space, w_axis)
 
     def descr_astype(self, space, w_dtype):
-        dtype = space.interp_w(interp_dtype.W_Dtype,
+        cur_dtype = self.get_dtype()
+        new_dtype = space.interp_w(interp_dtype.W_Dtype,
             space.call_function(space.gettypefor(interp_dtype.W_Dtype), w_dtype))
+        if new_dtype.shape:
+            raise oefmt(space.w_NotImplementedError,
+                "%s.astype(%s) not implemented yet", cur_dtype.name, new_dtype.name)
         impl = self.implementation
         if isinstance(impl, scalar.Scalar):
-            return W_NDimArray.new_scalar(space, dtype, impl.value)
+            return W_NDimArray.new_scalar(space, new_dtype, impl.value)
         else:
-            new_impl = impl.astype(space, dtype)
+            new_impl = impl.astype(space, new_dtype)
             return wrap_impl(space, space.type(self), self, new_impl)
 
     def descr_get_base(self, space):

pypy/module/micronumpy/test/test_numarray.py

         a = array(3.1415).astype('S3').dtype
         assert a.itemsize == 3
 
-        import sys
-        if '__pypy__' not in sys.builtin_module_names:
-            a = array(['1', '2','3']).astype(float)
-            assert a[2] == 3.0
-        else:
-            raises(NotImplementedError, array(['1', '2', '3']).astype, float)
+        a = array(['1', '2','3']).astype(float)
+        assert a[2] == 3.0
 
         a = array('123')
         assert a.astype('i8') == 123
         exc = raises(ValueError, "a[0, 0]['z']")
         assert exc.value.message == 'field named z not found'
 
+        import sys
+        a = array(1.5, dtype=float)
+        assert a.shape == ()
+        if '__pypy__' not in sys.builtin_module_names:
+            a = a.astype((float, 2))
+            repr(a)  # check for crash
+            assert a.shape == (2,)
+            assert tuple(a) == (1.5, 1.5)
+        else:
+            raises(NotImplementedError, "a.astype((float, 2))")
+
+        a = array([1.5], dtype=float)
+        assert a.shape == (1,)
+        if '__pypy__' not in sys.builtin_module_names:
+            a = a.astype((float, 2))
+            repr(a)  # check for crash
+            assert a.shape == (1, 2)
+            assert tuple(a[0]) == (1.5, 1.5)
+        else:
+            raises(NotImplementedError, "a.astype((float, 2))")
+
     def test_subarray_multiple_rows(self):
         import numpypy as np
         descr = [
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.