Commits

Romain Guillebert  committed 000f7ae Merge

Merge heads

  • Participants
  • Parent commits 7107aa4, 4bea52d
  • Branches numpypy-nditer

Comments (0)

Files changed (3)

File pypy/module/micronumpy/interp_nditer.py

         self.it.next()
 
     def getitem(self, space, array):
-        return self.op_flags.get_it_item(space, array, self.it)
+        return self.op_flags.get_it_item[self.index](space, array, self.it)
 
-class BoxIterator(IteratorMixin):
-    pass
+class BoxIterator(IteratorMixin, AbstractIterator):
+    index = 0
 
-class SliceIterator(IteratorMixin):
-    pass
+class ExternalLoopIterator(IteratorMixin, AbstractIterator):
+    index = 1
 
 def parse_op_arg(space, name, w_op_flags, n, parse_one_arg):
     ret = []
         self.native_byte_order = False
         self.tmp_copy = ''
         self.allocate = False
-        self.get_it_item = get_readonly_item
+        self.get_it_item = (get_readonly_item, get_readonly_slice)
 
 def get_readonly_item(space, array, it):
     return space.wrap(it.getitem())
             raise OperationError(space.w_ValueError, space.wrap(
                     'op_flags must be a tuple or array of per-op flag-tuples'))
         if op_flag.rw == 'r':
-            op_flag.get_it_item = get_readonly_item
+            op_flag.get_it_item = (get_readonly_item, get_readonly_slice)
         elif op_flag.rw == 'rw':
-            op_flag.get_it_item = get_readwrite_item
+            op_flag.get_it_item = (get_readwrite_item, get_readwrite_slice)
     return op_flag
 
 def parse_func_flags(space, nditer, w_flags):
                 'Iterator flag EXTERNAL_LOOP cannot be used if an index or '
                 'multi-index is being tracked'))
 
-def get_iter(space, order, imp, shape):
+def get_iter(space, order, arr, shape):
+    imp = arr.implementation
     if order == 'K' or (order == 'C' and imp.order == 'C'):
         backward = False
     elif order =='F' and imp.order == 'C':
                                     shape, backward)
     return MultiDimViewIterator(imp, imp.dtype, imp.start, r[0], r[1], shape)
 
+def get_external_loop_iter(space, order, arr, shape):
+    imp = arr.implementation
+    if order == 'K' or (order == 'C' and imp.order == 'C'):
+        backward = False
+    elif order =='F' and imp.order == 'C':
+        backward = True
+    else:
+        raise OperationError(space.w_NotImplementedError, space.wrap(
+                'not implemented yet'))
+
+    return SliceIterator(arr, imp.strides, imp.backstrides, shape, order=order, backward=backward)
+
 
 class W_NDIter(W_Root):
 
         self.iters=[]
         self.shape = iter_shape = shape_agreement_multiple(space, self.seq)
         if self.external_loop:
-            #XXX find longest contiguous shape
-            iter_shape = iter_shape[1:]
-        for i in range(len(self.seq)):
-            self.iters.append(BoxIterator(get_iter(space, self.order,
-                            self.seq[i].implementation, iter_shape), self.op_flags[i]))
+            for i in range(len(self.seq)):
+                self.iters.append(ExternalLoopIterator(get_external_loop_iter(space, self.order,
+                                self.seq[i], iter_shape), self.op_flags[i]))
+        else:
+            for i in range(len(self.seq)):
+                self.iters.append(BoxIterator(get_iter(space, self.order,
+                                self.seq[i], iter_shape), self.op_flags[i]))
 
     def descr_iter(self, space):
         return space.wrap(self)

File pypy/module/micronumpy/iter.py

      calculate_slice_strides
 from pypy.module.micronumpy.base import W_NDimArray
 from pypy.module.micronumpy.arrayimpl import base
+from pypy.module.micronumpy import support
 from rpython.rlib import jit
 
 # structures to describe slicing
         self.offset %= self.size
 
 class SliceIterator(object):
-    def __init__(self, arr, stride, backstride, shape, dtype=None):
-        self.step = 0
+    def __init__(self, arr, strides, backstrides, shape, order="C", backward=False, dtype=None):
+        self.indexes = [0] * (len(shape) - 1)
+        self.offset = 0
         self.arr = arr
-        self.stride = stride
-        self.backstride = backstride
-        self.shape = shape
         if dtype is None:
             dtype = arr.implementation.dtype
+        if backward:
+            self.slicesize = shape[0]
+            self.gap = [support.product(shape[1:]) * dtype.get_size()]
+            self.strides = strides[1:][::-1]
+            self.backstrides = backstrides[1:][::-1]
+            self.shape = shape[1:][::-1]
+            self.shapelen = len(self.shape)
+        else:
+            shape = [support.product(shape)]
+            self.strides, self.backstrides = support.calc_strides(shape, dtype, order)
+            self.slicesize = support.product(shape)
+            self.shapelen = 0
+            self.gap = self.strides
         self.dtype = dtype
         self._done = False
 
-    def done():
+    def done(self):
         return self._done
 
-    def next():
-        self.step += self.arr.implementation.dtype.get_size()
-        if self.step == self.backstride - self.implementation.dtype.get_size():
+    @jit.unroll_safe
+    def next(self):
+        offset = self.offset
+        for i in range(self.shapelen - 1, -1, -1):
+            if self.indexes[i] < self.shape[i] - 1:
+                self.indexes[i] += 1
+                offset += self.strides[i]
+                break
+            else:
+                self.indexes[i] = 0
+                offset -= self.backstrides[i]
+        else:
             self._done = True
+        self.offset = offset
 
     def getslice(self):
         from pypy.module.micronumpy.arrayimpl.concrete import SliceArray
-        return SliceArray(self.step, [self.stride], [self.backstride], self.shape, self.arr.implementation, self.arr, self.dtype)
+        return SliceArray(self.offset, self.gap, self.backstrides, [self.slicesize], self.arr.implementation, self.arr, self.dtype)
 
 class AxisIterator(base.BaseArrayIterator):
     def __init__(self, array, shape, dim, cumultative):

File pypy/module/micronumpy/test/test_nditer.py

 
     def test_external_loop(self):
         from numpypy import arange, nditer, array
-        a = arange(12).reshape(2,3,2)
+        a = arange(24).reshape(2, 3, 4)
         r = []
         n = 0
         for x in nditer(a, flags=['external_loop']):
             r.append(x)
             n += 1
         assert n == 1
-        assert (array(r) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]).all()
+        assert (array(r) == range(24)).all()
         r = []
         n = 0
         for x in nditer(a, flags=['external_loop'], order='F'):
             r.append(x)
             n += 1
-        assert n == 6
-        assert (array(r) == [[0, 6], [2, 8], [4, 10], [1, 7], [3, 9], [5, 11]]).all()
+        assert n == 12
+        assert (array(r) == [[ 0, 12], [ 4, 16], [ 8, 20], [ 1, 13], [ 5, 17], [ 9, 21], [ 2, 14], [ 6, 18], [10, 22], [ 3, 15], [ 7, 19], [11, 23]]).all()
 
     def test_interface(self):
         from numpypy import arange, nditer, zeros