Commits

Romain Guillebert committed 25afd81

Implement __array_prepare__ for non-scalar

Comments (0)

Files changed (2)

pypy/module/micronumpy/loop.py

 from pypy.module.micronumpy.iter import PureShapeIterator
 from pypy.module.micronumpy import constants
 from pypy.module.micronumpy.support import int_w
+from pypy.module.micronumpy import interp_boxes
 
-call2_driver = jit.JitDriver(name='numpy_call2',
-                             greens = ['shapelen', 'func', 'calc_dtype',
-                                       'res_dtype'],
-                             reds = ['shape', 'w_lhs', 'w_rhs', 'out',
-                                     'left_iter', 'right_iter', 'out_iter'])
-
-def call_prepare(self, space, w_out, w_obj, w_result):
-    if isinstance(w_out, W_NDimArray):
-        w_array = space.lookup(w_out, "__array_prepare__")
-        w_caller = w_out
-    else:
-        w_array = space.lookup(w_obj, "__array_prepare__")
-        w_caller = w_obj
+def call_prepare(space, w_obj, w_result):
+    w_array = space.lookup(w_obj, "__array_prepare__")
     if w_array:
-        w_retVal = space.get_and_call_function(w_array, w_caller, w_result, None)
+        w_retVal = space.get_and_call_function(w_array, w_obj, w_result, None)
         if not isinstance(w_retVal, W_NDimArray) and \
             not isinstance(w_retVal, interp_boxes.Box):
             raise OperationError(space.w_ValueError,
         return w_retVal
     return w_result
 
+call2_driver = jit.JitDriver(name='numpy_call2',
+                             greens = ['shapelen', 'func', 'calc_dtype',
+                                       'res_dtype'],
+                             reds = ['shape', 'w_lhs', 'w_rhs', 'out',
+                                     'left_iter', 'right_iter', 'out_iter'])
 def call2(space, shape, func, calc_dtype, res_dtype, w_lhs, w_rhs, out):
     # handle array_priority
     # w_lhs and w_rhs could be of different ndarray subtypes. Numpy does:
     if out is None:
         out = W_NDimArray.from_shape(space, shape, res_dtype,
                                      w_instance=lhs_for_subtype)
+        out = call_prepare(space, w_lhs, out)
+    else:
+        out = call_prepare(space, out, out)
+
     left_iter = w_lhs.create_iter(shape)
     right_iter = w_rhs.create_iter(shape)
     out_iter = out.create_iter(shape)
 def call1(space, shape, func, calc_dtype, res_dtype, w_obj, out):
     if out is None:
         out = W_NDimArray.from_shape(space, shape, res_dtype, w_instance=w_obj)
+        out = call_prepare(space, w_obj, out)
+    else:
+        out = call_prepare(space, out, out)
     obj_iter = w_obj.create_iter(shape)
     out_iter = out.create_iter(shape)
     shapelen = len(shape)

pypy/module/micronumpy/test/test_subtype.py

         assert type(x) == ndarray
         assert a.called_wrap
 
-    def test___array_prepare__2arg(self):
+    def test___array_prepare__2arg_scalar(self):
         from numpypy import ndarray, array, add, ones
         class with_prepare(ndarray):
             def __array_prepare__(self, arr, context):
         assert x.called_prepare
         raises(TypeError, add, a, b, out=c)
 
-    def test___array_prepare__1arg(self):
+    def test___array_prepare__1arg_scalar(self):
         from numpypy import ndarray, array, log, ones
         class with_prepare(ndarray):
             def __array_prepare__(self, arr, context):
         assert x.called_prepare
         raises(TypeError, log, a, out=c)
 
+    def test___array_prepare__2arg_array(self):
+        from numpypy import ndarray, array, add, ones
+        class with_prepare(ndarray):
+            def __array_prepare__(self, arr, context):
+                retVal = array(arr).view(type=with_prepare)
+                retVal.called_prepare = True
+                return retVal
+        class with_prepare_fail(ndarray):
+            called_prepare = False
+            def __array_prepare__(self, arr, context):
+                return array(arr[0]).view(type=with_prepare)
+        a = array([1])
+        b = array([1]).view(type=with_prepare)
+        x = add(a, a, out=b)
+        assert x == 2
+        assert type(x) == with_prepare
+        assert x.called_prepare
+        b.called_prepare = False
+        a = ones((3, 2)).view(type=with_prepare)
+        b = ones((3, 2))
+        c = ones((3, 2)).view(type=with_prepare_fail)
+        x = add(a, b, out=a)
+        assert (x == 2).all()
+        assert type(x) == with_prepare
+        assert x.called_prepare
+        raises(TypeError, add, a, b, out=c)
+
+    def test___array_prepare__1arg_array(self):
+        from numpypy import ndarray, array, log, ones
+        class with_prepare(ndarray):
+            def __array_prepare__(self, arr, context):
+                retVal = array(arr).view(type=with_prepare)
+                retVal.called_prepare = True
+                return retVal
+        class with_prepare_fail(ndarray):
+            def __array_prepare__(self, arr, context):
+                return array(arr[0]).view(type=with_prepare)
+        a = array([1])
+        b = array([1]).view(type=with_prepare)
+        print 'xxxxxxxxxxxxxxxxxxxxxxxxxxxxx'
+        x = log(a, out=b)
+        print 'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa'
+        assert x == 0
+        print 'xxxxxxxxxxxxxxxxxxxxxxxxxxxxx'
+        assert type(x) == with_prepare
+        assert x.called_prepare
+        x.called_prepare = False
+        a = ones((3, 2)).view(type=with_prepare)
+        b = ones((3, 2))
+        c = ones((3, 2)).view(type=with_prepare_fail)
+        x = log(a)
+        assert (x == 0).all()
+        assert type(x) == with_prepare
+        assert x.called_prepare
+        raises(TypeError, log, a, out=c)
 
     def test___array_prepare__reduce(self):
         from numpypy import ndarray, array, sum, ones, add