Maciej Fijalkowski avatar Maciej Fijalkowski committed 0fcad0c

implement keepdims=True

Comments (0)

Files changed (5)

pypy/module/micronumpy/interp_iter.py

 class AxisIterator(BaseIterator):
     def __init__(self, start, dim, shape, strides, backstrides):
         self.res_shape = shape[:]
-        self.strides = strides[:dim] + [0] + strides[dim:]
-        self.backstrides = backstrides[:dim] + [0] + backstrides[dim:]
+        if len(shape) == len(strides):
+            # keepdims = True
+            self.strides = strides[:dim] + [0] + strides[dim + 1:]
+            self.backstrides = backstrides[:dim] + [0] + backstrides[dim + 1:]
+        else:
+            self.strides = strides[:dim] + [0] + strides[dim:]
+            self.backstrides = backstrides[:dim] + [0] + backstrides[dim:]
         self.first_line = True
         self.indices = [0] * len(shape)
         self._done = False

pypy/module/micronumpy/interp_numarray.py

 def array(space, w_item_or_iterable, w_dtype=None, w_order=None,
           subok=True, copy=False, w_maskna=None, ownmaskna=False):
     # find scalar
+    if w_maskna is None:
+        w_maskna = space.w_None
     if (not subok or copy or not space.is_w(w_maskna, space.w_None) or
         ownmaskna):
         raise OperationError(space.w_NotImplementedError, space.wrap("Unsupported args"))
             space.call_function(space.gettypefor(interp_dtype.W_Dtype), w_dtype)
         )
         return scalar_w(space, dtype, w_item_or_iterable)
-    if space.is_w(w_order, space.w_None):
+    if space.is_w(w_order, space.w_None) or w_order is None:
         order = 'C'
     else:
         order = space.str_w(w_order)

pypy/module/micronumpy/interp_ufuncs.py

         return self.identity
 
     def descr_call(self, space, __args__):
-        if __args__.keywords or len(__args__.arguments_w) < self.argcount:
+        # XXX do something with strange keywords
+        if len(__args__.arguments_w) < self.argcount:
             raise OperationError(space.w_ValueError,
                 space.wrap("invalid number of arguments")
             )
 
     @unwrap_spec(skipna=bool, keepdims=bool)
     def descr_reduce(self, space, w_obj, w_axis=None, w_dtype=None,
-                     skipna=False, keepdims=True, w_out=None):
+                     skipna=False, keepdims=False, w_out=None):
         """reduce(...)
         reduce(a, axis=0)
 
             axis = -1
         else:
             axis = space.int_w(w_axis)
-        return self.reduce(space, w_obj, False, False, axis)
+        return self.reduce(space, w_obj, False, False, axis, keepdims)
 
-    def reduce(self, space, w_obj, multidim, promote_to_largest, dim):
+    def reduce(self, space, w_obj, multidim, promote_to_largest, dim, keepdims):
         from pypy.module.micronumpy.interp_numarray import convert_to_array, \
                                                            Scalar
         if self.argcount != 2:
             raise operationerrfmt(space.w_ValueError, "zero-size array to "
                     "%s.reduce without identity", self.name)
         if shapelen > 1 and dim >= 0:
-            res = self.do_axis_reduce(obj, dtype, dim)
+            res = self.do_axis_reduce(obj, dtype, dim, keepdims)
             return space.wrap(res)
         scalarsig = ScalarSignature(dtype)
         sig = find_sig(ReduceSignature(self.func, self.name, dtype,
             value = self.identity.convert_to(dtype)
         return self.reduce_loop(shapelen, sig, frame, value, obj, dtype)
 
-    def do_axis_reduce(self, obj, dtype, dim):
+    def do_axis_reduce(self, obj, dtype, dim, keepdims):
         from pypy.module.micronumpy.interp_numarray import AxisReduce,\
              W_NDimArray
-        
-        shape = obj.shape[0:dim] + obj.shape[dim + 1:len(obj.shape)]
+
+        if keepdims:
+            shape = obj.shape[:dim] + [1] + obj.shape[dim + 1:]
+        else:
+            shape = obj.shape[:dim] + obj.shape[dim + 1:]
         size = 1
         for s in shape:
             size *= s

pypy/module/micronumpy/strides.py

 from pypy.rlib import jit
-
+from pypy.interpreter.error import OperationError
 
 @jit.look_inside_iff(lambda shape, start, strides, backstrides, chunks:
     jit.isconstant(len(chunks))

pypy/module/micronumpy/test/test_ufuncs.py

         from _numpypy import sin, add
 
         raises(ValueError, sin.reduce, [1, 2, 3])
-        raises(ValueError, add.reduce, 1)
+        raises(TypeError, add.reduce, 1)
 
     def test_reduce_1d(self):
         from _numpypy import add, maximum
         assert (add.reduce(a, 0) == [12, 15, 18, 21]).all()
         assert (add.reduce(a, 1) == [6.0, 22.0, 38.0]).all()
 
+    def test_reduce_keepdims(self):
+        from _numpypy import add, arange
+        a = arange(12).reshape(3, 4)
+        b = add.reduce(a, 0, keepdims=True)
+        assert b.shape == (1, 4)
+        assert (add.reduce(a, 0, keepdims=True) == [12, 15, 18, 21]).all()
+        
+
     def test_bitwise(self):
         from _numpypy import bitwise_and, bitwise_or, arange, array
         a = arange(6).reshape(2, 3)
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.