Commits

mattip committed b3e61ca

fix for axis arg

  • Participants
  • Parent commits 7f81f9f
  • Branches nupypy-axis-arg-check

Comments (0)

Files changed (3)

pypy/module/micronumpy/interp_numarray.py

 from pypy.rlib.rstring import StringBuilder
 from pypy.rpython.lltypesystem import lltype, rffi
 from pypy.tool.sourcetools import func_with_new_name
+from pypy.rlib.rarithmetic import maxint
 
 
 count_driver = jit.JitDriver(
     def _reduce_ufunc_impl(ufunc_name, promote_to_largest=False):
         def impl(self, space, w_axis=None, w_out=None):
             if space.is_w(w_axis, space.w_None):
-                axis = -1
+                axis = maxint
             else:
                 axis = space.int_w(w_axis)
+                shapelen = len(self.shape)
+                if axis < -shapelen or axis>= shapelen:
+                    raise operationerrfmt(space.w_ValueError,
+                        "axis entry %d is out of bounds [%d, %d)", axis,
+                        -shapelen, shapelen)
+                if axis < 0:
+                    axis += shapelen
+
             if space.is_w(w_out, space.w_None) or not w_out:
                 out = None
             elif not isinstance(w_out, BaseArray):
 
     def descr_mean(self, space, w_axis=None, w_out=None):
         if space.is_w(w_axis, space.w_None):
-            w_axis = space.wrap(-1)
             w_denom = space.wrap(support.product(self.shape))
         else:
-            dim = space.int_w(w_axis)
-            w_denom = space.wrap(self.shape[dim])
+            axis = space.int_w(w_axis)
+            shapelen = len(self.shape)
+            if axis < -shapelen or axis>= shapelen:
+                raise operationerrfmt(space.w_ValueError,
+                    "axis entry %d is out of bounds [%d, %d)", axis,
+                    -shapelen, shapelen)
+            if axis < 0:    
+                axis += shapelen
+            w_denom = space.wrap(self.shape[axis])
         return space.div(self.descr_sum_promote(space, w_axis, w_out), w_denom)
 
     def descr_var(self, space, w_axis=None):
     def __init__(self, ufunc, name, identity, shape, dtype, left, right, dim):
         Call2.__init__(self, ufunc, name, shape, dtype, dtype,
                        left, right)
+        assert dim >= 0
         self.dim = dim
         self.identity = identity
 
     if space.is_w(w_axis, space.w_None):
         return space.wrap(support.product(arr.shape))
     if space.isinstance_w(w_axis, space.w_int):
-        return space.wrap(arr.shape[space.int_w(w_axis)])
+        axis = space.int_w(w_axis)
+        if axis < -arr.shapelen or axis>= arr.shapelen:
+            raise operationerrfmt(space.w_ValueError,
+                "axis entry %d is out of bounds [%d, %d)", axis,
+                -arr.shapelen, arr.shapelen)
+        return space.wrap(arr.shape[axis])    
+    # numpy as of June 2012 does not implement this 
     s = 1
     elems = space.fixedview(w_axis)
     for w_elem in elems:
-        s *= arr.shape[space.int_w(w_elem)]
+        axis = space.int_w(w_elem)
+        if axis < -arr.shapelen or axis>= arr.shapelen:
+            raise operationerrfmt(space.w_ValueError,
+                "axis entry %d is out of bounds [%d, %d)", axis,
+                -arr.shapelen, arr.shapelen)
+        s *= arr.shape[axis]
     return space.wrap(s)
 
 def dot(space, w_obj, w_obj2):

pypy/module/micronumpy/interp_ufuncs.py

 from pypy.rlib import jit
 from pypy.rlib.rarithmetic import LONG_BIT
 from pypy.tool.sourcetools import func_with_new_name
-
+from pypy.rlib.rarithmetic import maxint
 
 class W_Ufunc(Wrappable):
     _attrs_ = ["name", "promote_to_float", "promote_bools", "identity"]
         if w_axis is None:
             axis = 0
         elif space.is_w(w_axis, space.w_None):
-            axis = -1
+            axis = maxint
         else:
             axis = space.int_w(w_axis)
+            shapelen = len(self.shape)
+            if axis < -shapelen or axis>= shapelen:
+                raise operationerrfmt(space.w_ValueError,
+                    "axis entry %d is out of bounds [%d, %d)", axis,
+                    -shapelen, shapelen)
+            if axis < 0:
+                axis += shapelen
         if space.is_w(w_out, space.w_None):
             out = None
         elif not isinstance(w_out, BaseArray):
             raise OperationError(space.w_ValueError, space.wrap("reduce only "
                 "supported for binary functions"))
         assert isinstance(self, W_Ufunc2)
+        assert axis>=0
         obj = convert_to_array(space, w_obj)
-        if axis >= len(obj.shape):
-            raise OperationError(space.w_ValueError, space.wrap("axis(=%d) out of bounds" % axis))
         if isinstance(obj, Scalar):
             raise OperationError(space.w_TypeError, space.wrap("cannot reduce "
                 "on a scalar"))
         if self.identity is None and size == 0:
             raise operationerrfmt(space.w_ValueError, "zero-size array to "
                     "%s.reduce without identity", self.name)
-        if shapelen > 1 and axis >= 0:
+        if shapelen > 1 and axis < shapelen:
             if keepdims:
                 shape = obj.shape[:axis] + [1] + obj.shape[axis + 1:]
             else:

pypy/module/micronumpy/test/test_numarray.py

         a = array([True, False, True, False], dtype="?")
         b = array([True, True, False, False], dtype="?")
         c = a + b
+        print 'c.dtype',c.dtype
+        print 'c',c,'a',a,'b',b
+        print 'a+b',a+b
         for i in range(4):
             assert c[i] == bool(a[i] + b[i])
 
         assert (b == array(range(35, 70), dtype=float).reshape(5, 7)).all()
         assert (a.mean(2) == array(range(0, 15), dtype=float).reshape(3, 5) * 7 + 3).all()
         assert (arange(10).reshape(5, 2).mean(axis=1) == [0.5, 2.5, 4.5, 6.5, 8.5]).all()
+        assert (a.mean(axis=-1) == a.mean(axis=2)).all()
+        raises(ValueError, a.mean, -4)
+        raises(ValueError, a.mean, 3)
 
     def test_sum(self):
         from _numpypy import array
         a = array([True] * 5, bool)
         assert a.sum() == 5
 
-        raises(TypeError, 'a.sum(2, 3)')
+        raises(TypeError, 'a.sum(axis=0, out=3)')
+        raises(ValueError, 'a.sum(axis=2)')
         d = array(0.)
         b = a.sum(out=d)
         assert b == d