Commits

Brian Kearns committed 401a27a

fix scalar view as subtype of ndarray

Comments (0)

Files changed (2)

pypy/module/micronumpy/interp_boxes.py

 from pypy.objspace.std.unicodetype import unicode_typedef, unicode_from_object
 from pypy.objspace.std.inttype import int_typedef
 from pypy.objspace.std.complextype import complex_typedef
+from pypy.objspace.std.typeobject import W_TypeObject
 from rpython.rlib.rarithmetic import LONG_BIT
 from rpython.rtyper.lltypesystem import rffi
 from rpython.tool.sourcetools import func_with_new_name
 from pypy.module.micronumpy.arrayimpl.voidbox import VoidBoxStorage
+from pypy.module.micronumpy.base import W_NDimArray
 from pypy.module.micronumpy.interp_flagsobj import W_FlagsObject
 from pypy.interpreter.mixedmodule import MixedModule
 from rpython.rtyper.lltypesystem import lltype
 
     def descr_view(self, space, w_dtype):
         from pypy.module.micronumpy.interp_dtype import W_Dtype
-        dtype = space.interp_w(W_Dtype,
-            space.call_function(space.gettypefor(W_Dtype), w_dtype))
-        if dtype.get_size() == 0:
-            raise OperationError(space.w_TypeError, space.wrap(
-                "data-type must not be 0-sized"))
-        if dtype.get_size() != self.get_dtype(space).get_size():
-            raise OperationError(space.w_ValueError, space.wrap(
-                "new type not compatible with array."))
+        if type(w_dtype) is W_TypeObject and \
+                space.abstract_issubclass_w(w_dtype, space.gettypefor(W_NDimArray)):
+            dtype = self.get_dtype(space)
+        else:
+            dtype = space.interp_w(W_Dtype,
+                space.call_function(space.gettypefor(W_Dtype), w_dtype))
+            if dtype.get_size() == 0:
+                raise OperationError(space.w_TypeError, space.wrap(
+                    "data-type must not be 0-sized"))
+            if dtype.get_size() != self.get_dtype(space).get_size():
+                raise OperationError(space.w_ValueError, space.wrap(
+                    "new type not compatible with array."))
         if dtype.is_str_or_unicode():
             return dtype.coerce(space, space.wrap(self.raw_str()))
         elif dtype.is_record_type():

pypy/module/micronumpy/test/test_subtype.py

         b = matrix(a)
         assert isinstance(b, matrix)
         assert (b == a).all()
+        a = array(5)[()]
+        for s in [matrix, ndarray]:
+            b = a.view(s)
+            assert b == a
+            assert type(b) is type(a)
 
     def test_subtype_like_matrix(self):
         import numpy as np