Commits

Romain Guillebert  committed cbc60b2

Implement the c_index and the f_index flags on the nditer class

  • Participants
  • Parent commits 27ffa87
  • Branches numpypy-nditer

Comments (0)

Files changed (2)

File pypy/module/micronumpy/interp_nditer.py

                                     shape, backward)
     return MultiDimViewIterator(imp, imp.dtype, imp.start, r[0], r[1], shape)
 
+def is_backward(imp, order):
+    if order == 'K' or (order == 'C' and imp.order == 'C'):
+        return False
+    elif order =='F' and imp.order == 'C':
+        return True
+    else:
+        raise NotImplementedError('not implemented yet')
+
 def get_external_loop_iter(space, order, arr, shape):
     imp = arr.implementation
-    if order == 'K' or (order == 'C' and imp.order == 'C'):
-        backward = False
-    elif order =='F' and imp.order == 'C':
-        backward = True
-    else:
-        raise OperationError(space.w_NotImplementedError, space.wrap(
-                'not implemented yet'))
+
+    backward = is_backward(imp, order)
 
     return SliceIterator(arr, imp.strides, imp.backstrides, shape, order=order, backward=backward)
 
+class IndexIterator(object):
+    def __init__(self, shape, backward=False):
+        self.shape = shape
+        self.index = [0] * len(shape)
+        self.backward = backward
+        self.called = False
+
+    def next(self):
+        # TODO It's probably possible to refactor all the "next" method from each iterator
+        if not self.called:
+            self.called = True
+            return
+        for i in range(len(self.shape) - 1, -1, -1):
+            if self.index[i] < self.shape[i] - 1:
+                self.index[i] += 1
+                break
+            else:
+                self.index[i] = 0
+
+    def getvalue(self):
+        if not self.called:
+            return 0
+        if not self.backward:
+            ret = self.index[-1]
+            for i in range(len(self.shape) - 2, -1, -1):
+                ret += self.index[i] * self.shape[i - 1]
+        else:
+            ret = self.index[0]
+            for i in range(1, len(self.shape)):
+                ret += self.index[i] * self.shape[i - 1]
+        return ret
 
 class W_NDIter(W_Root):
 
         self.refs_ok = False
         self.reduce_ok = False
         self.zerosize_ok = False
+        self.index_iter = None
         if space.isinstance_w(w_seq, space.w_tuple) or \
            space.isinstance_w(w_seq, space.w_list):
             w_seq_as_list = space.listview(w_seq)
                                      len(self.seq), parse_op_flag)
         self.iters=[]
         self.shape = iter_shape = shape_agreement_multiple(space, self.seq)
+        if self.tracked_index != "":
+            if self.order == "K":
+                self.order = self.seq[0].implementation.order
+            self.index_iter = IndexIterator(iter_shape, backward=self.order != self.tracked_index)
         if self.external_loop:
             for i in range(len(self.seq)):
                 self.iters.append(ExternalLoopIterator(get_external_loop_iter(space, self.order,
         else:
             raise OperationError(space.w_StopIteration, space.w_None)
         res = []
+        if self.index_iter:
+            self.index_iter.next()
         for i in range(len(self.iters)):
             res.append(self.iters[i].getitem(space, self.seq[i]))
             self.iters[i].next()
             'not implemented yet'))
 
     def descr_get_has_index(self, space):
-        raise OperationError(space.w_NotImplementedError, space.wrap(
-            'not implemented yet'))
+        return space.wrap(not self.tracked_index == "")
 
     def descr_get_index(self, space):
-        raise OperationError(space.w_NotImplementedError, space.wrap(
-            'not implemented yet'))
+        if self.tracked_index == "":
+            raise OperationError(space.w_ValueError, "Iterator does not have an index")
+        return space.wrap(self.index_iter.getvalue())
 
     def descr_get_has_multi_index(self, space):
         raise OperationError(space.w_NotImplementedError, space.wrap(

File pypy/module/micronumpy/test/test_nditer.py

             e = ex
         assert e
 
+    def test_index(self):
+        from numpypy import arange, nditer, zeros
+        a = arange(6).reshape(2,3)
+
+        r = []
+        it = nditer(a, flags=['c_index'])
+        assert it.has_index
+        for value in it:
+            r.append((value, it.index))
+        assert r == [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5)]
+
+        r = []
+        it = nditer(a, flags=['f_index'])
+        assert it.has_index
+        for value in it:
+            r.append((value, it.index))
+        assert r == [(0, 0), (1, 2), (2, 4), (3, 1), (4, 3), (5, 5)]
+
     def test_interface(self):
         from numpypy import arange, nditer, zeros
         a = arange(6).reshape(2,3)