Commits

Brian Kearns committed 34f8c67

test and fix for passing out=None to ndarray.clip/choose

Comments (0)

Files changed (3)

pypy/module/micronumpy/interp_numarray.py

 
     @unwrap_spec(mode=str)
     def descr_choose(self, space, w_choices, mode='raise', w_out=None):
-        if w_out is not None and not isinstance(w_out, W_NDimArray):
+        if not space.is_none(w_out) and not isinstance(w_out, W_NDimArray):
             raise OperationError(space.w_TypeError, space.wrap(
                 "return arrays must be of ArrayType"))
         return interp_arrayops.choose(space, self, w_choices, w_out, mode)
 
     def descr_clip(self, space, w_min, w_max, w_out=None):
-        if w_out is not None and not isinstance(w_out, W_NDimArray):
+        if not space.is_none(w_out) and not isinstance(w_out, W_NDimArray):
             raise OperationError(space.w_TypeError, space.wrap(
                 "return arrays must be of ArrayType"))
         min = convert_to_array(space, w_min)

pypy/module/micronumpy/test/test_arrayops.py

     def test_choose_out(self):
         from _numpypy import array
         a, b, c = array([1, 2, 3]), [4, 5, 6], 13
+        r = array([2, 1, 0]).choose([a, b, c], out=None)
+        assert (r == [13, 5, 3]).all()
+        assert (a == [1, 2, 3]).all()
         r = array([2, 1, 0]).choose([a, b, c], out=a)
         assert (r == [13, 5, 3]).all()
         assert (a == [13, 5, 3]).all()
-        
+
     def test_choose_modes(self):
         from _numpypy import array
         a, b, c = array([1, 2, 3]), [4, 5, 6], 13

pypy/module/micronumpy/test/test_numarray.py

         from _numpypy import array
         a = array([1, 2, 17, -3, 12])
         assert (a.clip(-2, 13) == [1, 2, 13, -2, 12]).all()
-        assert (a.clip(-1, 1) == [1, 1, 1, -1, 1]).all()
+        assert (a.clip(-1, 1, out=None) == [1, 1, 1, -1, 1]).all()
+        assert (a == [1, 2, 17, -3, 12]).all()
         assert (a.clip(-1, [1, 2, 3, 4, 5]) == [1, 2, 3, -1, 5]).all()
         assert (a.clip(-2, 13, out=a) == [1, 2, 13, -2, 12]).all()
         assert (a == [1, 2, 13, -2, 12]).all()