Commits

Maciej Fijalkowski committed e28b7a7

implement diagonal. I wonder if the jit would help here

  • Participants
  • Parent commits 254c849
  • Branches missing-ndarray-attributes

Comments (0)

Files changed (4)

pypy/module/micronumpy/interp_arrayops.py

 
 def diagonal(space, arr, offset, axis1, axis2):
     shape = arr.get_shape()
+    shapelen = len(shape)
+    if offset < 0:
+        offset = -offset
+        axis1, axis2 = axis2, axis1
     size = min(shape[axis1], shape[axis2] - offset)
     dtype = arr.dtype
-    if len(shape) == 2:
+    if axis1 < axis2:
+        shape = (shape[:axis1] + shape[axis1 + 1:axis2] +
+                 shape[axis2 + 1:] + [size])
+    else:
+        shape = (shape[:axis2] + shape[axis2 + 1:axis1] +
+                 shape[axis1 + 1:] + [size])
+    out = W_NDimArray.from_shape(shape, dtype)
+    if size == 0:
+        return out
+    if shapelen == 2:
         # simple case
-        out = W_NDimArray.from_shape([size], dtype)
         loop.diagonal_simple(space, arr, out, offset, axis1, axis2, size)
-        return out
     else:
-        xxx
+        loop.diagonal_array(space, arr, out, offset, axis1, axis2, shape)
+    return out

pypy/module/micronumpy/interp_numarray.py

             raise operationerrfmt(space.w_ValueError,
                  "axis1(=%d) and axis2(=%d) must be withing range (ndim=%d)",
                                   axis1, axis2, len(self.get_shape()))
+        if axis1 == axis2:
+            raise OperationError(space.w_ValueError, space.wrap(
+                "axis1 and axis2 cannot be the same"))
         return interp_arrayops.diagonal(space, self.implementation, offset,
                                         axis1, axis2)
     

pypy/module/micronumpy/loop.py

 from pypy.rlib import jit
 from pypy.rpython.lltypesystem import lltype, rffi
 from pypy.module.micronumpy.base import W_NDimArray
-from pypy.module.micronumpy.iter import PureShapeIterator
+from pypy.module.micronumpy.iter import PureShapeIterator, ConcreteArrayIterator
 from pypy.module.micronumpy import constants
 from pypy.module.micronumpy.support import int_w
 
         out_iter.setitem(arr.getitem_index(space, index))
         i += 1
         out_iter.next()
+
+def diagonal_array(space, arr, out, offset, axis1, axis2, shape):
+    out_iter = out.create_iter()
+    iter = PureShapeIterator(shape, [])
+    shapelen = len(shape)
+    while not iter.done():
+        last_index = iter.indexes[-1]
+        if axis1 < axis2:
+            indexes = (iter.indexes[:axis1] + [last_index] +
+                       iter.indexes[axis1:axis2 - 1] + [last_index + offset] +
+                       iter.indexes[axis2 - 1:shapelen - 1])
+        else:
+            indexes = (iter.indexes[:axis2] + [last_index + offset] +
+                       iter.indexes[axis2:axis1 - 1] + [last_index] +
+                       iter.indexes[axis1 - 1:shapelen - 1])
+        out_iter.setitem(arr.getitem_index(space, indexes))
+        iter.next()
+        out_iter.next()

pypy/module/micronumpy/test/test_numarray.py

         from _numpypy import array
         a = array([[1, 2], [3, 4], [5, 6]])
         raises(ValueError, 'array([1, 2]).diagonal()')
+        raises(ValueError, 'a.diagonal(0, 0, 0)')
+        raises(ValueError, 'a.diagonal(0, 0, 13)')
         assert (a.diagonal() == [1, 4]).all()
         assert (a.diagonal(1) == [2]).all()
 
+    def test_diagonal_axis(self):
+        from _numpypy import arange
+        a = arange(12).reshape(2, 3, 2)
+        assert (a.diagonal(0, 0, 1) == [[0, 8], [1, 9]]).all()
+        assert a.diagonal(3, 0, 1).shape == (2, 0)
+        assert (a.diagonal(1, 0, 1) == [[2, 10], [3, 11]]).all()         
+        assert (a.diagonal(0, 2, 1) == [[0, 3], [6, 9]]).all()        
+        assert (a.diagonal(2, 2, 1) == [[4], [10]]).all()        
+        assert (a.diagonal(1, 2, 1) == [[2, 5], [8, 11]]).all()        
+
+    def test_diagonal_axis_neg_ofs(self):
+        from _numpypy import arange
+        a = arange(12).reshape(2, 3, 2)
+        assert (a.diagonal(-1, 0, 1) == [[6], [7]]).all()
+        assert a.diagonal(-2, 0, 1).shape == (2, 0)
+
 class AppTestSupport(BaseNumpyAppTest):
     def setup_class(cls):
         import struct