Commits

mattip committed 881739f

fix, add more passing tests

Comments (0)

Files changed (3)

pypy/module/micronumpy/interp_iter.py

             self.done = True
 
     def get_dim_index(self):
-        return self.indices[0]
+        return self.indices[self.dimorder[0]]

pypy/module/micronumpy/interp_numarray.py

                 return out
             return Scalar(res_dtype, res_dtype.box(result))
         def do_axisminmax(self, space, axis, out):
+            # Use a AxisFirstIterator to walk along self, with dimensions
+            # reordered to move along 'axis' fastest. Every time 'axis' 's
+            # index is 0, move to the next value of out.
             dtype = self.find_dtype()
             source = AxisFirstIterator(self, axis)
             dest = ViewIterator(out.start, out.strides, out.backstrides, 
             firsttime = True
             while not source.done:
                 cur_val = self.getitem(source.offset)
-                #print 'indices are',source.indices
                 cur_index = source.get_dim_index()
                 if cur_index == 0:
                     if not firsttime:
                     firsttime = False    
                     cur_best = cur_val
                     out.setitem(dest.offset, dtype.box(0))
-                    #print 'setting out[',dest.offset,'] to 0'
                 else:
                     new_best = getattr(dtype.itemtype, op_name)(cur_best, cur_val)
                     if dtype.itemtype.ne(new_best, cur_best):
                         cur_best = new_best
                         out.setitem(dest.offset, dtype.box(cur_index))
-                        #print 'setting out[',dest.offset,'] to',cur_index
                 source.next()
             return out
 

pypy/module/micronumpy/test/test_numarray.py

         assert a.argmax() == 5
         assert a[:2, ].argmax() == 3
 
+    def test_argmax_axis(self):
+        from _numpypy import array
+        # Some random values, tested via cut-and-paste
+        # from numpy
+        vals = [57, 42, 57, 20, 81, 82, 65, 16, 52, 32,
+                24, 95, 99,  4, 86, 60, 38, 28, 67, 45,
+                68, 66, 13, 76, 98, 96, 61,  4,  0, 13,
+                94, 30, 36, 89, 31, 54, 43,  6, 58, 84,
+                15, 22, 41,  3, 49, 81, 65, 53, 85, 14, 
+                56, 37, 60, 11, 77, 9, 16, 80, 94, 43]
+        a = array(vals).reshape(5,3,4)
+        b = a.argmax(0)
+        assert (b == [[1, 2, 1, 3],
+                      [0, 0, 2, 1],
+                      [1, 2, 4, 0]]).all()
+        b = a.argmax(1)
+        assert (b == [[1, 1, 1, 2],
+                      [0, 2, 0, 2],
+                      [0, 0, 1, 2],
+                      [2, 2, 2, 0],
+                      [0, 2, 2, 2]]).all()
+        b = a.argmax(2)
+        assert (b == [[0, 1, 3], [0, 2, 3],
+                      [0, 2, 1], [3, 2, 1],
+                      [0, 2, 2]]).all()
+        b = a[:,2,:].argmax(1)
+        assert(b == [3, 3, 1, 1, 2]).all()
+
+
+
     def test_broadcast_wrong_shapes(self):
         from _numpypy import zeros
         a = zeros((4, 3, 2))