Commits

Brian Kearns committed 667ad75

simplify nditer

Comments (0)

Files changed (2)

pypy/module/micronumpy/iterators.py

         self.array.setitem(state.offset, elem)
 
 
-class SliceIterator(ArrayIter):
-    def __init__(self, arr, strides, backstrides, shape, order="C",
-                    backward=False, dtype=None):
-        if dtype is None:
-            dtype = arr.implementation.dtype
-        self.dtype = dtype
-        self.arr = arr
-        if backward:
-            self.slicesize = shape[0]
-            self.gap = [support.product(shape[1:]) * dtype.elsize]
-            strides = strides[1:]
-            backstrides = backstrides[1:]
-            shape = shape[1:]
-            strides.reverse()
-            backstrides.reverse()
-            shape.reverse()
-            size = support.product(shape)
-        else:
-            shape = [support.product(shape)]
-            strides, backstrides = calc_strides(shape, dtype, order)
-            size = 1
-            self.slicesize = support.product(shape)
-            self.gap = strides
-        ArrayIter.__init__(self, arr.implementation, size, shape, strides, backstrides)
-
-    def getslice(self):
-        from pypy.module.micronumpy.concrete import SliceArray
-        return SliceArray(self.offset, self.gap, self.backstrides,
-                          [self.slicesize], self.arr.implementation,
-                          self.arr, self.dtype)
-
-
 def AxisIter(array, shape, axis, cumulative):
     strides = array.get_strides()
     backstrides = array.get_backstrides()

pypy/module/micronumpy/nditer.py

 from pypy.module.micronumpy import ufuncs, support, concrete
 from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
 from pypy.module.micronumpy.descriptor import decode_w_dtype
-from pypy.module.micronumpy.iterators import ArrayIter, SliceIterator
+from pypy.module.micronumpy.iterators import ArrayIter
 from pypy.module.micronumpy.strides import (calculate_broadcast_strides,
                                             shape_agreement, shape_agreement_multiple)
 
 
-class Iterator(object):
-    def __init__(self, nditer, index, it, op_flags):
-        self.nditer = nditer
-        self.index = index
-        self.it = it
-        self.st = it.reset()
-        self.op_flags = op_flags
-
-    def done(self):
-        return self.it.done(self.st)
-
-    def next(self):
-        self.st = self.it.next(self.st)
-
-    def getitem(self, space, array):
-        return self.op_flags.get_it_item[self.index](space, self.nditer, self.it, self.st)
-
-    def setitem(self, space, array, val):
-        xxx
-
-
 def parse_op_arg(space, name, w_op_flags, n, parse_one_arg):
-    ret = []
     if space.is_w(w_op_flags, space.w_None):
-        for i in range(n):
-            ret.append(OpFlag())
-    elif not space.isinstance_w(w_op_flags, space.w_tuple) and not \
+        w_op_flags = space.newtuple([space.wrap('readonly')])
+    if not space.isinstance_w(w_op_flags, space.w_tuple) and not \
             space.isinstance_w(w_op_flags, space.w_list):
         raise oefmt(space.w_ValueError,
                     '%s must be a tuple or array of per-op flag-tuples',
                     name)
+    ret = []
+    w_lst = space.listview(w_op_flags)
+    if space.isinstance_w(w_lst[0], space.w_tuple) or \
+       space.isinstance_w(w_lst[0], space.w_list):
+        if len(w_lst) != n:
+            raise oefmt(space.w_ValueError,
+                        '%s must be a tuple or array of per-op flag-tuples',
+                        name)
+        for item in w_lst:
+            ret.append(parse_one_arg(space, space.listview(item)))
     else:
-        w_lst = space.listview(w_op_flags)
-        if space.isinstance_w(w_lst[0], space.w_tuple) or \
-           space.isinstance_w(w_lst[0], space.w_list):
-            if len(w_lst) != n:
-                raise oefmt(space.w_ValueError,
-                            '%s must be a tuple or array of per-op flag-tuples',
-                            name)
-            for item in w_lst:
-                ret.append(parse_one_arg(space, space.listview(item)))
-        else:
-            op_flag = parse_one_arg(space, w_lst)
-            for i in range(n):
-                ret.append(op_flag)
+        op_flag = parse_one_arg(space, w_lst)
+        for i in range(n):
+            ret.append(op_flag)
     return ret
 
 
         self.native_byte_order = False
         self.tmp_copy = ''
         self.allocate = False
-        self.get_it_item = (get_readonly_item, get_readonly_slice)
-
-
-def get_readonly_item(space, nditer, it, st):
-    res = concrete.ConcreteNonWritableArrayWithBase(
-        [], it.array.dtype, it.array.order, [], [], it.array.storage, nditer)
-    res.start = st.offset
-    return W_NDimArray(res)
-
-
-def get_readwrite_item(space, nditer, it, st):
-    res = concrete.ConcreteArrayWithBase(
-        [], it.array.dtype, it.array.order, [], [], it.array.storage, nditer)
-    res.start = st.offset
-    return W_NDimArray(res)
-
-
-def get_readonly_slice(space, array, it):
-    return W_NDimArray(it.getslice().readonly())
-
-
-def get_readwrite_slice(space, array, it):
-    return W_NDimArray(it.getslice())
 
 
 def parse_op_flag(space, lst):
         else:
             raise OperationError(space.w_ValueError, space.wrap(
                 'op_flags must be a tuple or array of per-op flag-tuples'))
-        if op_flag.rw == '':
-            raise oefmt(space.w_ValueError,
-                        "None of the iterator flags READWRITE, READONLY, or "
-                        "WRITEONLY were specified for an operand")
-        elif op_flag.rw == 'r':
-            op_flag.get_it_item = (get_readonly_item, get_readonly_slice)
-        elif op_flag.rw == 'rw':
-            op_flag.get_it_item = (get_readwrite_item, get_readwrite_slice)
-        elif op_flag.rw == 'w':
-            # XXX Extra logic needed to make sure writeonly
-            op_flag.get_it_item = (get_readwrite_item, get_readwrite_slice)
+    if op_flag.rw == '':
+        raise oefmt(space.w_ValueError,
+                    "None of the iterator flags READWRITE, READONLY, or "
+                    "WRITEONLY were specified for an operand")
     return op_flag
 
 
     return ArrayIter(imp, imp.get_size(), shape, r[0], r[1])
 
 
-def get_external_loop_iter(space, order, arr, shape):
-    imp = arr.implementation
-    backward = is_backward(imp, order)
-    return SliceIterator(arr, imp.strides, imp.backstrides, shape, order=order, backward=backward)
-
-
 class IndexIterator(object):
     def __init__(self, shape, backward=False):
         self.shape = shape
                 out_dtype = None
                 for i in range(len(self.seq)):
                     if self.seq[i] is None:
-                        self.op_flags[i].get_it_item = (get_readwrite_item,
-                                                        get_readwrite_slice)
                         self.op_flags[i].allocate = True
                         continue
                     if self.op_flags[i].rw == 'w':
             self.dtypes = [s.get_dtype() for s in self.seq]
 
         # create an iterator for each operand
-        if self.external_loop:
-            for i in range(len(self.seq)):
-                self.iters.append(Iterator(
-                    self, 1,
-                    get_external_loop_iter(
-                        space, self.order, self.seq[i], iter_shape),
-                    self.op_flags[i]))
-        else:
-            for i in range(len(self.seq)):
-                self.iters.append(Iterator(
-                    self, 0,
-                    get_iter(
-                        space, self.order, self.seq[i], iter_shape, self.dtypes[i]),
-                    self.op_flags[i]))
+        for i in range(len(self.seq)):
+            it = get_iter(space, self.order, self.seq[i], iter_shape, self.dtypes[i])
+            self.iters.append((it, it.reset()))
 
     def set_op_axes(self, space, w_op_axes):
         if space.len_w(w_op_axes) != len(self.seq):
     def descr_iter(self, space):
         return space.wrap(self)
 
+    def getitem(self, it, st, op_flags):
+        if op_flags.rw == 'r':
+            impl = concrete.ConcreteNonWritableArrayWithBase
+        else:
+            impl = concrete.ConcreteArrayWithBase
+        res = impl([], it.array.dtype, it.array.order, [], [],
+                   it.array.storage, self)
+        res.start = st.offset
+        return W_NDimArray(res)
+
     def descr_getitem(self, space, w_idx):
         idx = space.int_w(w_idx)
         try:
-            ret = space.wrap(self.iters[idx].getitem(space, self.seq[idx]))
+            it, st = self.iters[idx]
         except IndexError:
             raise oefmt(space.w_IndexError,
                         "Iterator operand index %d is out of bounds", idx)
-        return ret
+        return self.getitem(it, st, self.op_flags[idx])
 
     def descr_setitem(self, space, w_idx, w_value):
         raise oefmt(space.w_NotImplementedError, "not implemented yet")
         space.wrap(len(self.iters))
 
     def descr_next(self, space):
-        for it in self.iters:
-            if not it.done():
+        for it, st in self.iters:
+            if not it.done(st):
                 break
         else:
             self.done = True
                 self.index_iter.next()
             else:
                 self.first_next = False
-        for i in range(len(self.iters)):
-            res.append(self.iters[i].getitem(space, self.seq[i]))
-            self.iters[i].next()
+        for i, (it, st) in enumerate(self.iters):
+            res.append(self.getitem(it, st, self.op_flags[i]))
+            self.iters[i] = (it, it.next(st))
         if len(res) < 2:
             return res[0]
         return space.newtuple(res)
     def iternext(self):
         if self.index_iter:
             self.index_iter.next()
-        for i in range(len(self.iters)):
-            self.iters[i].next()
-        for it in self.iters:
-            if not it.done():
+        for i, (it, st) in enumerate(self.iters):
+            self.iters[i] = (it, it.next(st))
+        for it, st in self.iters:
+            if not it.done(st):
                 break
         else:
             self.done = True