Commits

Brian Kearns committed f48615b Merge

merge numpy-speed

Comments (0)

Files changed (10)

pypy/module/micronumpy/concrete.py

                                             self.get_backstrides(),
                                             self.get_shape(), shape,
                                             backward_broadcast)
-            return ArrayIter(self, support.product(shape), shape, r[0], r[1])
-        return ArrayIter(self, self.get_size(), self.shape,
-                         self.strides, self.backstrides)
+            i = ArrayIter(self, support.product(shape), shape, r[0], r[1])
+        else:
+            i = ArrayIter(self, self.get_size(), self.shape,
+                          self.strides, self.backstrides)
+        return i, i.reset()
 
     def swapaxes(self, space, orig_arr, axis1, axis2):
         shape = self.get_shape()[:]

pypy/module/micronumpy/ctors.py

             "string is smaller than requested size"))
 
     a = W_NDimArray.from_shape(space, [num_items], dtype=dtype)
-    ai = a.create_iter()
+    ai, state = a.create_iter()
     for val in items:
-        ai.setitem(val)
-        ai.next()
+        ai.setitem(state, val)
+        state = ai.next(state)
 
     return space.wrap(a)
 

pypy/module/micronumpy/flatiter.py

         self.reset()
 
     def reset(self):
-        self.iter = self.base.create_iter()
+        self.iter, self.state = self.base.create_iter()
 
     def descr_len(self, space):
         return space.wrap(self.base.get_size())
 
     def descr_next(self, space):
-        if self.iter.done():
+        if self.iter.done(self.state):
             raise OperationError(space.w_StopIteration, space.w_None)
-        w_res = self.iter.getitem()
-        self.iter.next()
+        w_res = self.iter.getitem(self.state)
+        self.state = self.iter.next(self.state)
         return w_res
 
     def descr_index(self, space):
-        return space.wrap(self.iter.index)
+        return space.wrap(self.state.index)
 
     def descr_coords(self, space):
-        coords = self.base.to_coords(space, space.wrap(self.iter.index))
+        coords = self.base.to_coords(space, space.wrap(self.state.index))
         return space.newtuple([space.wrap(c) for c in coords])
 
     def descr_getitem(self, space, w_idx):
         self.reset()
         base = self.base
         start, stop, step, length = space.decode_index4(w_idx, base.get_size())
-        base_iter = base.create_iter()
-        base_iter.next_skip_x(start)
+        base_iter, base_state = base.create_iter()
+        base_state = base_iter.next_skip_x(base_state, start)
         if length == 1:
-            return base_iter.getitem()
+            return base_iter.getitem(base_state)
         res = W_NDimArray.from_shape(space, [length], base.get_dtype(),
                                      base.get_order(), w_instance=base)
-        return loop.flatiter_getitem(res, base_iter, step)
+        return loop.flatiter_getitem(res, base_iter, base_state, step)
 
     def descr_setitem(self, space, w_idx, w_value):
         if not (space.isinstance_w(w_idx, space.w_int) or

pypy/module/micronumpy/iterators.py

         self.shapelen = len(shape)
         self.indexes = [0] * len(shape)
         self._done = False
-        self.idx_w = [None] * len(idx_w)
+        self.idx_w_i = [None] * len(idx_w)
+        self.idx_w_s = [None] * len(idx_w)
         for i, w_idx in enumerate(idx_w):
             if isinstance(w_idx, W_NDimArray):
-                self.idx_w[i] = w_idx.create_iter(shape)
+                self.idx_w_i[i], self.idx_w_s[i] = w_idx.create_iter(shape)
 
     def done(self):
         return self._done
 
     @jit.unroll_safe
     def next(self):
-        for w_idx in self.idx_w:
-            if w_idx is not None:
-                w_idx.next()
+        for i, idx_w_i in enumerate(self.idx_w_i):
+            if idx_w_i is not None:
+                self.idx_w_s[i] = idx_w_i.next(self.idx_w_s[i])
         for i in range(self.shapelen - 1, -1, -1):
             if self.indexes[i] < self.shape[i] - 1:
                 self.indexes[i] += 1
         return [space.wrap(self.indexes[i]) for i in range(shapelen)]
 
 
+class IterState(object):
+    _immutable_fields_ = ['index', 'indices[*]', 'offset']
+
+    def __init__(self, index, indices, offset):
+        self.index = index
+        self.indices = indices
+        self.offset = offset
+
+
 class ArrayIter(object):
     _immutable_fields_ = ['array', 'size', 'ndim_m1', 'shape_m1[*]',
                           'strides[*]', 'backstrides[*]']
         self.strides = strides
         self.backstrides = backstrides
 
-        self.index = 0
-        self.indices = [0] * len(shape)
-        self.offset = array.start
+    def reset(self):
+        return IterState(0, [0] * len(self.shape_m1), self.array.start)
 
     @jit.unroll_safe
-    def reset(self):
-        self.index = 0
+    def next(self, state):
+        index = state.index + 1
+        indices = state.indices
+        offset = state.offset
         for i in xrange(self.ndim_m1, -1, -1):
-            self.indices[i] = 0
-        self.offset = self.array.start
+            idx = indices[i]
+            if idx < self.shape_m1[i]:
+                indices[i] = idx + 1
+                offset += self.strides[i]
+                break
+            else:
+                indices[i] = 0
+                offset -= self.backstrides[i]
+        return IterState(index, indices, offset)
 
     @jit.unroll_safe
-    def next(self):
-        self.index += 1
+    def next_skip_x(self, state, step):
+        assert step >= 0
+        if step == 0:
+            return state
+        index = state.index + step
+        indices = state.indices
+        offset = state.offset
         for i in xrange(self.ndim_m1, -1, -1):
-            idx = self.indices[i]
-            if idx < self.shape_m1[i]:
-                self.indices[i] = idx + 1
-                self.offset += self.strides[i]
+            idx = indices[i]
+            if idx < (self.shape_m1[i] + 1) - step:
+                indices[i] = idx + step
+                offset += self.strides[i] * step
                 break
             else:
-                self.indices[i] = 0
-                self.offset -= self.backstrides[i]
-
-    @jit.unroll_safe
-    def next_skip_x(self, step):
-        assert step >= 0
-        if step == 0:
-            return
-        self.index += step
-        for i in xrange(self.ndim_m1, -1, -1):
-            idx = self.indices[i]
-            if idx < (self.shape_m1[i] + 1) - step:
-                self.indices[i] = idx + step
-                self.offset += self.strides[i] * step
-                break
-            else:
-                rem_step = (self.indices[i] + step) // (self.shape_m1[i] + 1)
+                rem_step = (idx + step) // (self.shape_m1[i] + 1)
                 cur_step = step - rem_step * (self.shape_m1[i] + 1)
-                self.indices[i] += cur_step
-                self.offset += self.strides[i] * cur_step
+                indices[i] = idx + cur_step
+                offset += self.strides[i] * cur_step
                 step = rem_step
                 assert step > 0
+        return IterState(index, indices, offset)
 
-    def done(self):
-        return self.index >= self.size
+    def done(self, state):
+        return state.index >= self.size
 
-    def getitem(self):
-        return self.array.getitem(self.offset)
+    def getitem(self, state):
+        return self.array.getitem(state.offset)
 
-    def getitem_bool(self):
-        return self.array.getitem_bool(self.offset)
+    def getitem_bool(self, state):
+        return self.array.getitem_bool(state.offset)
 
-    def setitem(self, elem):
-        self.array.setitem(self.offset, elem)
+    def setitem(self, state, elem):
+        self.array.setitem(state.offset, elem)
 
 
 class SliceIterator(ArrayIter):

pypy/module/micronumpy/loop.py

     AllButAxisIter
 
 
-call2_driver = jit.JitDriver(name='numpy_call2',
-                             greens = ['shapelen', 'func', 'calc_dtype',
-                                       'res_dtype'],
-                             reds = ['shape', 'w_lhs', 'w_rhs', 'out',
-                                     'left_iter', 'right_iter', 'out_iter'])
+call2_driver = jit.JitDriver(
+    name='numpy_call2',
+    greens=['shapelen', 'func', 'calc_dtype', 'res_dtype'],
+    reds='auto')
 
 def call2(space, shape, func, calc_dtype, res_dtype, w_lhs, w_rhs, out):
     # handle array_priority
     if out is None:
         out = W_NDimArray.from_shape(space, shape, res_dtype,
                                      w_instance=lhs_for_subtype)
-    left_iter = w_lhs.create_iter(shape)
-    right_iter = w_rhs.create_iter(shape)
-    out_iter = out.create_iter(shape)
+    left_iter, left_state = w_lhs.create_iter(shape)
+    right_iter, right_state = w_rhs.create_iter(shape)
+    out_iter, out_state = out.create_iter(shape)
     shapelen = len(shape)
-    while not out_iter.done():
+    while not out_iter.done(out_state):
         call2_driver.jit_merge_point(shapelen=shapelen, func=func,
-                                     calc_dtype=calc_dtype, res_dtype=res_dtype,
-                                     shape=shape, w_lhs=w_lhs, w_rhs=w_rhs,
-                                     out=out,
-                                     left_iter=left_iter, right_iter=right_iter,
-                                     out_iter=out_iter)
-        w_left = left_iter.getitem().convert_to(space, calc_dtype)
-        w_right = right_iter.getitem().convert_to(space, calc_dtype)
-        out_iter.setitem(func(calc_dtype, w_left, w_right).convert_to(
+                                     calc_dtype=calc_dtype, res_dtype=res_dtype)
+        w_left = left_iter.getitem(left_state).convert_to(space, calc_dtype)
+        w_right = right_iter.getitem(right_state).convert_to(space, calc_dtype)
+        out_iter.setitem(out_state, func(calc_dtype, w_left, w_right).convert_to(
             space, res_dtype))
-        left_iter.next()
-        right_iter.next()
-        out_iter.next()
+        left_state = left_iter.next(left_state)
+        right_state = right_iter.next(right_state)
+        out_state = out_iter.next(out_state)
     return out
 
-call1_driver = jit.JitDriver(name='numpy_call1',
-                             greens = ['shapelen', 'func', 'calc_dtype',
-                                       'res_dtype'],
-                             reds = ['shape', 'w_obj', 'out', 'obj_iter',
-                                     'out_iter'])
+call1_driver = jit.JitDriver(
+    name='numpy_call1',
+    greens=['shapelen', 'func', 'calc_dtype', 'res_dtype'],
+    reds='auto')
 
 def call1(space, shape, func, calc_dtype, res_dtype, w_obj, out):
     if out is None:
         out = W_NDimArray.from_shape(space, shape, res_dtype, w_instance=w_obj)
-    obj_iter = w_obj.create_iter(shape)
-    out_iter = out.create_iter(shape)
+    obj_iter, obj_state = w_obj.create_iter(shape)
+    out_iter, out_state = out.create_iter(shape)
     shapelen = len(shape)
-    while not out_iter.done():
+    while not out_iter.done(out_state):
         call1_driver.jit_merge_point(shapelen=shapelen, func=func,
-                                     calc_dtype=calc_dtype, res_dtype=res_dtype,
-                                     shape=shape, w_obj=w_obj, out=out,
-                                     obj_iter=obj_iter, out_iter=out_iter)
-        elem = obj_iter.getitem().convert_to(space, calc_dtype)
-        out_iter.setitem(func(calc_dtype, elem).convert_to(space, res_dtype))
-        out_iter.next()
-        obj_iter.next()
+                                     calc_dtype=calc_dtype, res_dtype=res_dtype)
+        elem = obj_iter.getitem(obj_state).convert_to(space, calc_dtype)
+        out_iter.setitem(out_state, func(calc_dtype, elem).convert_to(space, res_dtype))
+        out_state = out_iter.next(out_state)
+        obj_state = obj_iter.next(obj_state)
     return out
 
 setslice_driver = jit.JitDriver(name='numpy_setslice',
 def setslice(space, shape, target, source):
     # note that unlike everything else, target and source here are
     # array implementations, not arrays
-    target_iter = target.create_iter(shape)
-    source_iter = source.create_iter(shape)
+    target_iter, target_state = target.create_iter(shape)
+    source_iter, source_state = source.create_iter(shape)
     dtype = target.dtype
     shapelen = len(shape)
-    while not target_iter.done():
+    while not target_iter.done(target_state):
         setslice_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
+        val = source_iter.getitem(source_state)
         if dtype.is_str_or_unicode():
-            target_iter.setitem(dtype.coerce(space, source_iter.getitem()))
+            val = dtype.coerce(space, val)
         else:
-            target_iter.setitem(source_iter.getitem().convert_to(space, dtype))
-        target_iter.next()
-        source_iter.next()
+            val = val.convert_to(space, dtype)
+        target_iter.setitem(target_state, val)
+        target_state = target_iter.next(target_state)
+        source_state = source_iter.next(source_state)
     return target
 
 reduce_driver = jit.JitDriver(name='numpy_reduce',
                               reds = 'auto')
 
 def compute_reduce(space, obj, calc_dtype, func, done_func, identity):
-    obj_iter = obj.create_iter()
+    obj_iter, obj_state = obj.create_iter()
     if identity is None:
-        cur_value = obj_iter.getitem().convert_to(space, calc_dtype)
-        obj_iter.next()
+        cur_value = obj_iter.getitem(obj_state).convert_to(space, calc_dtype)
+        obj_state = obj_iter.next(obj_state)
     else:
         cur_value = identity.convert_to(space, calc_dtype)
     shapelen = len(obj.get_shape())
-    while not obj_iter.done():
+    while not obj_iter.done(obj_state):
         reduce_driver.jit_merge_point(shapelen=shapelen, func=func,
                                       done_func=done_func,
                                       calc_dtype=calc_dtype)
-        rval = obj_iter.getitem().convert_to(space, calc_dtype)
+        rval = obj_iter.getitem(obj_state).convert_to(space, calc_dtype)
         if done_func is not None and done_func(calc_dtype, rval):
             return rval
         cur_value = func(calc_dtype, cur_value, rval)
-        obj_iter.next()
+        obj_state = obj_iter.next(obj_state)
     return cur_value
 
 reduce_cum_driver = jit.JitDriver(name='numpy_reduce_cum_driver',
                                   reds = 'auto')
 
 def compute_reduce_cumulative(space, obj, out, calc_dtype, func, identity):
-    obj_iter = obj.create_iter()
-    out_iter = out.create_iter()
+    obj_iter, obj_state = obj.create_iter()
+    out_iter, out_state = out.create_iter()
     if identity is None:
-        cur_value = obj_iter.getitem().convert_to(space, calc_dtype)
-        out_iter.setitem(cur_value)
-        out_iter.next()
-        obj_iter.next()
+        cur_value = obj_iter.getitem(obj_state).convert_to(space, calc_dtype)
+        out_iter.setitem(out_state, cur_value)
+        out_state = out_iter.next(out_state)
+        obj_state = obj_iter.next(obj_state)
     else:
         cur_value = identity.convert_to(space, calc_dtype)
     shapelen = len(obj.get_shape())
-    while not obj_iter.done():
+    while not obj_iter.done(obj_state):
         reduce_cum_driver.jit_merge_point(shapelen=shapelen, func=func,
                                           dtype=calc_dtype)
-        rval = obj_iter.getitem().convert_to(space, calc_dtype)
+        rval = obj_iter.getitem(obj_state).convert_to(space, calc_dtype)
         cur_value = func(calc_dtype, cur_value, rval)
-        out_iter.setitem(cur_value)
-        out_iter.next()
-        obj_iter.next()
+        out_iter.setitem(out_state, cur_value)
+        out_state = out_iter.next(out_state)
+        obj_state = obj_iter.next(obj_state)
 
 def fill(arr, box):
-    arr_iter = arr.create_iter()
-    while not arr_iter.done():
-        arr_iter.setitem(box)
-        arr_iter.next()
+    arr_iter, arr_state = arr.create_iter()
+    while not arr_iter.done(arr_state):
+        arr_iter.setitem(arr_state, box)
+        arr_state = arr_iter.next(arr_state)
 
 def assign(space, arr, seq):
-    arr_iter = arr.create_iter()
+    arr_iter, arr_state = arr.create_iter()
     arr_dtype = arr.get_dtype()
     for item in seq:
-        arr_iter.setitem(arr_dtype.coerce(space, item))
-        arr_iter.next()
+        arr_iter.setitem(arr_state, arr_dtype.coerce(space, item))
+        arr_state = arr_iter.next(arr_state)
 
 where_driver = jit.JitDriver(name='numpy_where',
                              greens = ['shapelen', 'dtype', 'arr_dtype'],
                              reds = 'auto')
 
 def where(space, out, shape, arr, x, y, dtype):
-    out_iter = out.create_iter(shape)
-    arr_iter = arr.create_iter(shape)
+    out_iter, out_state = out.create_iter(shape)
+    arr_iter, arr_state = arr.create_iter(shape)
     arr_dtype = arr.get_dtype()
-    x_iter = x.create_iter(shape)
-    y_iter = y.create_iter(shape)
+    x_iter, x_state = x.create_iter(shape)
+    y_iter, y_state = y.create_iter(shape)
     if x.is_scalar():
         if y.is_scalar():
-            iter = arr_iter
+            iter, state = arr_iter, arr_state
         else:
-            iter = y_iter
+            iter, state = y_iter, y_state
     else:
-        iter = x_iter
+        iter, state = x_iter, x_state
     shapelen = len(shape)
-    while not iter.done():
+    while not iter.done(state):
         where_driver.jit_merge_point(shapelen=shapelen, dtype=dtype,
                                         arr_dtype=arr_dtype)
-        w_cond = arr_iter.getitem()
+        w_cond = arr_iter.getitem(arr_state)
         if arr_dtype.itemtype.bool(w_cond):
-            w_val = x_iter.getitem().convert_to(space, dtype)
+            w_val = x_iter.getitem(x_state).convert_to(space, dtype)
         else:
-            w_val = y_iter.getitem().convert_to(space, dtype)
-        out_iter.setitem(w_val)
-        out_iter.next()
-        arr_iter.next()
-        x_iter.next()
-        y_iter.next()
+            w_val = y_iter.getitem(y_state).convert_to(space, dtype)
+        out_iter.setitem(out_state, w_val)
+        out_state = out_iter.next(out_state)
+        arr_state = arr_iter.next(arr_state)
+        x_state = x_iter.next(x_state)
+        y_state = y_iter.next(y_state)
+        if x.is_scalar():
+            if y.is_scalar():
+                state = arr_state
+            else:
+                state = y_state
+        else:
+            state = x_state
     return out
 
 axis_reduce__driver = jit.JitDriver(name='numpy_axis_reduce',
 def do_axis_reduce(space, shape, func, arr, dtype, axis, out, identity, cumulative,
                    temp):
     out_iter = AxisIter(out.implementation, arr.get_shape(), axis, cumulative)
+    out_state = out_iter.reset()
     if cumulative:
         temp_iter = AxisIter(temp.implementation, arr.get_shape(), axis, False)
+        temp_state = temp_iter.reset()
     else:
-        temp_iter = out_iter # hack
-    arr_iter = arr.create_iter()
+        temp_iter = out_iter  # hack
+        temp_state = out_state
+    arr_iter, arr_state = arr.create_iter()
     if identity is not None:
         identity = identity.convert_to(space, dtype)
     shapelen = len(shape)
-    while not out_iter.done():
+    while not out_iter.done(out_state):
         axis_reduce__driver.jit_merge_point(shapelen=shapelen, func=func,
                                             dtype=dtype)
-        assert not arr_iter.done()
-        w_val = arr_iter.getitem().convert_to(space, dtype)
-        if out_iter.indices[axis] == 0:
+        assert not arr_iter.done(arr_state)
+        w_val = arr_iter.getitem(arr_state).convert_to(space, dtype)
+        if out_state.indices[axis] == 0:
             if identity is not None:
                 w_val = func(dtype, identity, w_val)
         else:
-            cur = temp_iter.getitem()
+            cur = temp_iter.getitem(temp_state)
             w_val = func(dtype, cur, w_val)
-        out_iter.setitem(w_val)
+        out_iter.setitem(out_state, w_val)
+        out_state = out_iter.next(out_state)
         if cumulative:
-            temp_iter.setitem(w_val)
-            temp_iter.next()
-        arr_iter.next()
-        out_iter.next()
+            temp_iter.setitem(temp_state, w_val)
+            temp_state = temp_iter.next(temp_state)
+        else:
+            temp_state = out_state
+        arr_state = arr_iter.next(arr_state)
     return out
 
 
         result = 0
         idx = 1
         dtype = arr.get_dtype()
-        iter = arr.create_iter()
-        cur_best = iter.getitem()
-        iter.next()
+        iter, state = arr.create_iter()
+        cur_best = iter.getitem(state)
+        state = iter.next(state)
         shapelen = len(arr.get_shape())
-        while not iter.done():
+        while not iter.done(state):
             arg_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
-            w_val = iter.getitem()
+            w_val = iter.getitem(state)
             new_best = getattr(dtype.itemtype, op_name)(cur_best, w_val)
             if dtype.itemtype.ne(new_best, cur_best):
                 result = idx
                 cur_best = new_best
-            iter.next()
+            state = iter.next(state)
             idx += 1
         return result
     return argmin_argmax
     right_impl = right.implementation
     assert left_shape[-1] == right_shape[right_critical_dim]
     assert result.get_dtype() == dtype
-    outi = result.create_iter()
+    outi, outs = result.create_iter()
     lefti = AllButAxisIter(left_impl, len(left_shape) - 1)
     righti = AllButAxisIter(right_impl, right_critical_dim)
+    lefts = lefti.reset()
+    rights = righti.reset()
     n = left_impl.shape[-1]
     s1 = left_impl.strides[-1]
     s2 = right_impl.strides[right_critical_dim]
-    while not lefti.done():
-        while not righti.done():
-            oval = outi.getitem()
-            i1 = lefti.offset
-            i2 = righti.offset
+    while not lefti.done(lefts):
+        while not righti.done(rights):
+            oval = outi.getitem(outs)
+            i1 = lefts.offset
+            i2 = rights.offset
             i = 0
             while i < n:
                 i += 1
                 oval = dtype.itemtype.add(oval, dtype.itemtype.mul(lval, rval))
                 i1 += s1
                 i2 += s2
-            outi.setitem(oval)
-            outi.next()
-            righti.next()
-        righti.reset()
-        lefti.next()
+            outi.setitem(outs, oval)
+            outs = outi.next(outs)
+            rights = righti.next(rights)
+        rights = righti.reset()
+        lefts = lefti.next(lefts)
     return result
 
 count_all_true_driver = jit.JitDriver(name = 'numpy_count',
 
 def count_all_true_concrete(impl):
     s = 0
-    iter = impl.create_iter()
+    iter, state = impl.create_iter()
     shapelen = len(impl.shape)
     dtype = impl.dtype
-    while not iter.done():
+    while not iter.done(state):
         count_all_true_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
-        s += iter.getitem_bool()
-        iter.next()
+        s += iter.getitem_bool(state)
+        state = iter.next(state)
     return s
 
 def count_all_true(arr):
                                reds = 'auto')
 
 def nonzero(res, arr, box):
-    res_iter = res.create_iter()
-    arr_iter = arr.create_iter()
+    res_iter, res_state = res.create_iter()
+    arr_iter, arr_state = arr.create_iter()
     shapelen = len(arr.shape)
     dtype = arr.dtype
     dims = range(shapelen)
-    while not arr_iter.done():
+    while not arr_iter.done(arr_state):
         nonzero_driver.jit_merge_point(shapelen=shapelen, dims=dims, dtype=dtype)
-        if arr_iter.getitem_bool():
+        if arr_iter.getitem_bool(arr_state):
             for d in dims:
-                res_iter.setitem(box(arr_iter.indices[d]))
-                res_iter.next()
-        arr_iter.next()
+                res_iter.setitem(res_state, box(arr_state.indices[d]))
+                res_state = res_iter.next(res_state)
+        arr_state = arr_iter.next(arr_state)
     return res
 
 
                                       reds = 'auto')
 
 def getitem_filter(res, arr, index):
-    res_iter = res.create_iter()
+    res_iter, res_state = res.create_iter()
     shapelen = len(arr.get_shape())
     if shapelen > 1 and len(index.get_shape()) < 2:
-        index_iter = index.create_iter(arr.get_shape(), backward_broadcast=True)
+        index_iter, index_state = index.create_iter(arr.get_shape(), backward_broadcast=True)
     else:
-        index_iter = index.create_iter()
-    arr_iter = arr.create_iter()
+        index_iter, index_state = index.create_iter()
+    arr_iter, arr_state = arr.create_iter()
     arr_dtype = arr.get_dtype()
     index_dtype = index.get_dtype()
     # XXX length of shape of index as well?
-    while not index_iter.done():
+    while not index_iter.done(index_state):
         getitem_filter_driver.jit_merge_point(shapelen=shapelen,
                                               index_dtype=index_dtype,
                                               arr_dtype=arr_dtype,
                                               )
-        if index_iter.getitem_bool():
-            res_iter.setitem(arr_iter.getitem())
-            res_iter.next()
-        index_iter.next()
-        arr_iter.next()
+        if index_iter.getitem_bool(index_state):
+            res_iter.setitem(res_state, arr_iter.getitem(arr_state))
+            res_state = res_iter.next(res_state)
+        index_state = index_iter.next(index_state)
+        arr_state = arr_iter.next(arr_state)
     return res
 
 setitem_filter_driver = jit.JitDriver(name = 'numpy_setitem_bool',
                                       reds = 'auto')
 
 def setitem_filter(space, arr, index, value):
-    arr_iter = arr.create_iter()
+    arr_iter, arr_state = arr.create_iter()
     shapelen = len(arr.get_shape())
     if shapelen > 1 and len(index.get_shape()) < 2:
-        index_iter = index.create_iter(arr.get_shape(), backward_broadcast=True)
+        index_iter, index_state = index.create_iter(arr.get_shape(), backward_broadcast=True)
     else:
-        index_iter = index.create_iter()
+        index_iter, index_state = index.create_iter()
     if value.get_size() == 1:
-        value_iter = value.create_iter(arr.get_shape())
+        value_iter, value_state = value.create_iter(arr.get_shape())
     else:
-        value_iter = value.create_iter()
+        value_iter, value_state = value.create_iter()
     index_dtype = index.get_dtype()
     arr_dtype = arr.get_dtype()
-    while not index_iter.done():
+    while not index_iter.done(index_state):
         setitem_filter_driver.jit_merge_point(shapelen=shapelen,
                                               index_dtype=index_dtype,
                                               arr_dtype=arr_dtype,
                                              )
-        if index_iter.getitem_bool():
-            arr_iter.setitem(arr_dtype.coerce(space, value_iter.getitem()))
-            value_iter.next()
-        arr_iter.next()
-        index_iter.next()
+        if index_iter.getitem_bool(index_state):
+            val = arr_dtype.coerce(space, value_iter.getitem(value_state))
+            value_state = value_iter.next(value_state)
+            arr_iter.setitem(arr_state, val)
+        arr_state = arr_iter.next(arr_state)
+        index_state = index_iter.next(index_state)
 
 flatiter_getitem_driver = jit.JitDriver(name = 'numpy_flatiter_getitem',
                                         greens = ['dtype'],
                                         reds = 'auto')
 
-def flatiter_getitem(res, base_iter, step):
-    ri = res.create_iter()
+def flatiter_getitem(res, base_iter, base_state, step):
+    ri, rs = res.create_iter()
     dtype = res.get_dtype()
-    while not ri.done():
+    while not ri.done(rs):
         flatiter_getitem_driver.jit_merge_point(dtype=dtype)
-        ri.setitem(base_iter.getitem())
-        base_iter.next_skip_x(step)
-        ri.next()
+        ri.setitem(rs, base_iter.getitem(base_state))
+        base_state = base_iter.next_skip_x(base_state, step)
+        rs = ri.next(rs)
     return res
 
 flatiter_setitem_driver = jit.JitDriver(name = 'numpy_flatiter_setitem',
 
 def flatiter_setitem(space, arr, val, start, step, length):
     dtype = arr.get_dtype()
-    arr_iter = arr.create_iter()
-    val_iter = val.create_iter()
-    arr_iter.next_skip_x(start)
+    arr_iter, arr_state = arr.create_iter()
+    val_iter, val_state = val.create_iter()
+    arr_state = arr_iter.next_skip_x(arr_state, start)
     while length > 0:
         flatiter_setitem_driver.jit_merge_point(dtype=dtype)
+        val = val_iter.getitem(val_state)
         if dtype.is_str_or_unicode():
-            arr_iter.setitem(dtype.coerce(space, val_iter.getitem()))
+            val = dtype.coerce(space, val)
         else:
-            arr_iter.setitem(val_iter.getitem().convert_to(space, dtype))
+            val = val.convert_to(space, dtype)
+        arr_iter.setitem(arr_state, val)
         # need to repeat i_nput values until all assignments are done
-        arr_iter.next_skip_x(step)
+        arr_state = arr_iter.next_skip_x(arr_state, step)
+        val_state = val_iter.next(val_state)
         length -= 1
-        val_iter.next()
 
 fromstring_driver = jit.JitDriver(name = 'numpy_fromstring',
                                   greens = ['itemsize', 'dtype'],
 
 def fromstring_loop(space, a, dtype, itemsize, s):
     i = 0
-    ai = a.create_iter()
-    while not ai.done():
+    ai, state = a.create_iter()
+    while not ai.done(state):
         fromstring_driver.jit_merge_point(dtype=dtype, itemsize=itemsize)
         sub = s[i*itemsize:i*itemsize + itemsize]
         if dtype.is_str_or_unicode():
             val = dtype.coerce(space, space.wrap(sub))
         else:
             val = dtype.itemtype.runpack_str(space, sub)
-        ai.setitem(val)
-        ai.next()
+        ai.setitem(state, val)
+        state = ai.next(state)
         i += 1
 
 def tostring(space, arr):
     builder = StringBuilder()
-    iter = arr.create_iter()
+    iter, state = arr.create_iter()
     w_res_str = W_NDimArray.from_shape(space, [1], arr.get_dtype(), order='C')
     itemsize = arr.get_dtype().elsize
     res_str_casted = rffi.cast(rffi.CArrayPtr(lltype.Char),
                                w_res_str.implementation.get_storage_as_int(space))
-    while not iter.done():
-        w_res_str.implementation.setitem(0, iter.getitem())
+    while not iter.done(state):
+        w_res_str.implementation.setitem(0, iter.getitem(state))
         for i in range(itemsize):
             builder.append(res_str_casted[i])
-        iter.next()
+        state = iter.next(state)
     return builder.build()
 
 getitem_int_driver = jit.JitDriver(name = 'numpy_getitem_int',
         # prepare the index
         index_w = [None] * indexlen
         for i in range(indexlen):
-            if iter.idx_w[i] is not None:
-                index_w[i] = iter.idx_w[i].getitem()
+            if iter.idx_w_i[i] is not None:
+                index_w[i] = iter.idx_w_i[i].getitem(iter.idx_w_s[i])
             else:
                 index_w[i] = indexes_w[i]
         res.descr_setitem(space, space.newtuple(prefix_w[:prefixlen] +
         # prepare the index
         index_w = [None] * indexlen
         for i in range(indexlen):
-            if iter.idx_w[i] is not None:
-                index_w[i] = iter.idx_w[i].getitem()
+            if iter.idx_w_i[i] is not None:
+                index_w[i] = iter.idx_w_i[i].getitem(iter.idx_w_s[i])
             else:
                 index_w[i] = indexes_w[i]
         w_idx = space.newtuple(prefix_w[:prefixlen] + iter.get_index(space,
 
 def byteswap(from_, to):
     dtype = from_.dtype
-    from_iter = from_.create_iter()
-    to_iter = to.create_iter()
-    while not from_iter.done():
+    from_iter, from_state = from_.create_iter()
+    to_iter, to_state = to.create_iter()
+    while not from_iter.done(from_state):
         byteswap_driver.jit_merge_point(dtype=dtype)
-        to_iter.setitem(dtype.itemtype.byteswap(from_iter.getitem()))
-        to_iter.next()
-        from_iter.next()
+        val = dtype.itemtype.byteswap(from_iter.getitem(from_state))
+        to_iter.setitem(to_state, val)
+        to_state = to_iter.next(to_state)
+        from_state = from_iter.next(from_state)
 
 choose_driver = jit.JitDriver(name='numpy_choose_driver',
                               greens = ['shapelen', 'mode', 'dtype'],
 
 def choose(space, arr, choices, shape, dtype, out, mode):
     shapelen = len(shape)
-    iterators = [a.create_iter(shape) for a in choices]
-    arr_iter = arr.create_iter(shape)
-    out_iter = out.create_iter(shape)
-    while not arr_iter.done():
+    pairs = [a.create_iter(shape) for a in choices]
+    iterators = [i[0] for i in pairs]
+    states = [i[1] for i in pairs]
+    arr_iter, arr_state = arr.create_iter(shape)
+    out_iter, out_state = out.create_iter(shape)
+    while not arr_iter.done(arr_state):
         choose_driver.jit_merge_point(shapelen=shapelen, dtype=dtype,
                                       mode=mode)
-        index = support.index_w(space, arr_iter.getitem())
+        index = support.index_w(space, arr_iter.getitem(arr_state))
         if index < 0 or index >= len(iterators):
             if mode == NPY.RAISE:
                 raise OperationError(space.w_ValueError, space.wrap(
                     index = 0
                 else:
                     index = len(iterators) - 1
-        out_iter.setitem(iterators[index].getitem().convert_to(space, dtype))
-        for iter in iterators:
-            iter.next()
-        out_iter.next()
-        arr_iter.next()
+        val = iterators[index].getitem(states[index]).convert_to(space, dtype)
+        out_iter.setitem(out_state, val)
+        for i in range(len(iterators)):
+            states[i] = iterators[i].next(states[i])
+        out_state = out_iter.next(out_state)
+        arr_state = arr_iter.next(arr_state)
 
 clip_driver = jit.JitDriver(name='numpy_clip_driver',
                             greens = ['shapelen', 'dtype'],
                             reds = 'auto')
 
 def clip(space, arr, shape, min, max, out):
-    arr_iter = arr.create_iter(shape)
+    arr_iter, arr_state = arr.create_iter(shape)
     dtype = out.get_dtype()
     shapelen = len(shape)
-    min_iter = min.create_iter(shape)
-    max_iter = max.create_iter(shape)
-    out_iter = out.create_iter(shape)
-    while not arr_iter.done():
+    min_iter, min_state = min.create_iter(shape)
+    max_iter, max_state = max.create_iter(shape)
+    out_iter, out_state = out.create_iter(shape)
+    while not arr_iter.done(arr_state):
         clip_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
-        w_v = arr_iter.getitem().convert_to(space, dtype)
-        w_min = min_iter.getitem().convert_to(space, dtype)
-        w_max = max_iter.getitem().convert_to(space, dtype)
+        w_v = arr_iter.getitem(arr_state).convert_to(space, dtype)
+        w_min = min_iter.getitem(min_state).convert_to(space, dtype)
+        w_max = max_iter.getitem(max_state).convert_to(space, dtype)
         if dtype.itemtype.lt(w_v, w_min):
             w_v = w_min
         elif dtype.itemtype.gt(w_v, w_max):
             w_v = w_max
-        out_iter.setitem(w_v)
-        arr_iter.next()
-        max_iter.next()
-        out_iter.next()
-        min_iter.next()
+        out_iter.setitem(out_state, w_v)
+        arr_state = arr_iter.next(arr_state)
+        min_state = min_iter.next(min_state)
+        max_state = max_iter.next(max_state)
+        out_state = out_iter.next(out_state)
 
 round_driver = jit.JitDriver(name='numpy_round_driver',
                              greens = ['shapelen', 'dtype'],
                              reds = 'auto')
 
 def round(space, arr, dtype, shape, decimals, out):
-    arr_iter = arr.create_iter(shape)
+    arr_iter, arr_state = arr.create_iter(shape)
+    out_iter, out_state = out.create_iter(shape)
     shapelen = len(shape)
-    out_iter = out.create_iter(shape)
-    while not arr_iter.done():
+    while not arr_iter.done(arr_state):
         round_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
-        w_v = arr_iter.getitem().convert_to(space, dtype)
+        w_v = arr_iter.getitem(arr_state).convert_to(space, dtype)
         w_v = dtype.itemtype.round(w_v, decimals)
-        out_iter.setitem(w_v)
-        arr_iter.next()
-        out_iter.next()
+        out_iter.setitem(out_state, w_v)
+        arr_state = arr_iter.next(arr_state)
+        out_state = out_iter.next(out_state)
 
 diagonal_simple_driver = jit.JitDriver(name='numpy_diagonal_simple_driver',
                                        greens = ['axis1', 'axis2'],
                                        reds = 'auto')
 
 def diagonal_simple(space, arr, out, offset, axis1, axis2, size):
-    out_iter = out.create_iter()
+    out_iter, out_state = out.create_iter()
     i = 0
     index = [0] * 2
     while i < size:
         diagonal_simple_driver.jit_merge_point(axis1=axis1, axis2=axis2)
         index[axis1] = i
         index[axis2] = i + offset
-        out_iter.setitem(arr.getitem_index(space, index))
+        out_iter.setitem(out_state, arr.getitem_index(space, index))
         i += 1
-        out_iter.next()
+        out_state = out_iter.next(out_state)
 
 def diagonal_array(space, arr, out, offset, axis1, axis2, shape):
-    out_iter = out.create_iter()
+    out_iter, out_state = out.create_iter()
     iter = PureShapeIter(shape, [])
     shapelen_minus_1 = len(shape) - 1
     assert shapelen_minus_1 >= 0
             indexes = (iter.indexes[:a] + [last_index + offset] +
                        iter.indexes[a:b] + [last_index] +
                        iter.indexes[b:shapelen_minus_1])
-        out_iter.setitem(arr.getitem_index(space, indexes))
+        out_iter.setitem(out_state, arr.getitem_index(space, indexes))
         iter.next()
-        out_iter.next()
+        out_state = out_iter.next(out_state)

pypy/module/micronumpy/ndarray.py

         return space.call_function(cache.w_array_str, self)
 
     def dump_data(self, prefix='array(', separator=',', suffix=')'):
-        i = self.create_iter()
+        i, state = self.create_iter()
         first = True
         dtype = self.get_dtype()
         s = StringBuilder()
         s.append(prefix)
         if not self.is_scalar():
             s.append('[')
-        while not i.done():
+        while not i.done(state):
             if first:
                 first = False
             else:
                 s.append(separator)
                 s.append(' ')
             if self.is_scalar() and dtype.is_str():
-                s.append(dtype.itemtype.to_str(i.getitem()))
+                s.append(dtype.itemtype.to_str(i.getitem(state)))
             else:
-                s.append(dtype.itemtype.str_format(i.getitem()))
-            i.next()
+                s.append(dtype.itemtype.str_format(i.getitem(state)))
+            state = i.next(state)
         if not self.is_scalar():
             s.append(']')
         s.append(suffix)
         if self.get_size() > 1:
             raise OperationError(space.w_ValueError, space.wrap(
                 "The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()"))
-        iter = self.create_iter()
-        return space.wrap(space.is_true(iter.getitem()))
+        iter, state = self.create_iter()
+        return space.wrap(space.is_true(iter.getitem(state)))
 
     def _binop_impl(ufunc_name):
         def impl(self, space, w_other, w_out=None):
 
         builder = StringBuilder()
         if isinstance(self.implementation, SliceArray):
-            iter = self.implementation.create_iter()
-            while not iter.done():
-                box = iter.getitem()
+            iter, state = self.implementation.create_iter()
+            while not iter.done(state):
+                box = iter.getitem(state)
                 builder.append(box.raw_str())
-                iter.next()
+                state = iter.next(state)
         else:
             builder.append_charpsize(self.implementation.get_storage(), self.implementation.get_storage_size())
 

pypy/module/micronumpy/nditer.py

 
     def __init__(self, it, op_flags):
         self.it = it
+        self.st = it.reset()
         self.op_flags = op_flags
 
     def done(self):
-        return self.it.done()
+        return self.it.done(self.st)
 
     def next(self):
-        self.it.next()
+        self.st = self.it.next(self.st)
 
     def getitem(self, space, array):
-        return self.op_flags.get_it_item[self.index](space, array, self.it)
+        return self.op_flags.get_it_item[self.index](space, array, self.it, self.st)
 
     def setitem(self, space, array, val):
         xxx
         self.get_it_item = (get_readonly_item, get_readonly_slice)
 
 
-def get_readonly_item(space, array, it):
-    return space.wrap(it.getitem())
+def get_readonly_item(space, array, it, st):
+    return space.wrap(it.getitem(st))
 
 
-def get_readwrite_item(space, array, it):
+def get_readwrite_item(space, array, it, st):
     #create a single-value view (since scalars are not views)
-    res = SliceArray(it.array.start + it.offset, [0], [0], [1], it.array, array)
+    res = SliceArray(it.array.start + st.offset, [0], [0], [1], it.array, array)
     #it.dtype.setitem(res, 0, it.getitem())
     return W_NDimArray(res)
 

pypy/module/micronumpy/sort.py

             if axis < 0 or axis >= len(shape):
                 raise oefmt(space.w_IndexError, "Wrong axis %d", axis)
             arr_iter = AllButAxisIter(arr, axis)
+            arr_state = arr_iter.reset()
             index_impl = index_arr.implementation
             index_iter = AllButAxisIter(index_impl, axis)
+            index_state = index_iter.reset()
             stride_size = arr.strides[axis]
             index_stride_size = index_impl.strides[axis]
             axis_size = arr.shape[axis]
-            while not arr_iter.done():
+            while not arr_iter.done(arr_state):
                 for i in range(axis_size):
                     raw_storage_setitem(storage, i * index_stride_size +
-                                        index_iter.offset, i)
+                                        index_state.offset, i)
                 r = Repr(index_stride_size, stride_size, axis_size,
-                         arr.get_storage(), storage, index_iter.offset, arr_iter.offset)
+                         arr.get_storage(), storage, index_state.offset, arr_state.offset)
                 ArgSort(r).sort()
-                arr_iter.next()
-                index_iter.next()
+                arr_state = arr_iter.next(arr_state)
+                index_state = index_iter.next(index_state)
         return index_arr
 
     return argsort
             if axis < 0 or axis >= len(shape):
                 raise oefmt(space.w_IndexError, "Wrong axis %d", axis)
             arr_iter = AllButAxisIter(arr, axis)
+            arr_state = arr_iter.reset()
             stride_size = arr.strides[axis]
             axis_size = arr.shape[axis]
-            while not arr_iter.done():
-                r = Repr(stride_size, axis_size, arr.get_storage(), arr_iter.offset)
+            while not arr_iter.done(arr_state):
+                r = Repr(stride_size, axis_size, arr.get_storage(), arr_state.offset)
                 ArgSort(r).sort()
-                arr_iter.next()
+                arr_state = arr_iter.next(arr_state)
 
     return sort
 

pypy/module/micronumpy/test/test_iterators.py

         assert backstrides == [10, 4]
         i = ArrayIter(MockArray, support.product(shape), shape,
                       strides, backstrides)
-        i.next()
-        i.next()
-        i.next()
-        assert i.offset == 3
-        assert not i.done()
-        assert i.indices == [0,3]
+        s = i.reset()
+        s = i.next(s)
+        s = i.next(s)
+        s = i.next(s)
+        assert s.offset == 3
+        assert not i.done(s)
+        assert s.indices == [0,3]
         #cause a dimension overflow
-        i.next()
-        i.next()
-        assert i.offset == 5
-        assert i.indices == [1,0]
+        s = i.next(s)
+        s = i.next(s)
+        assert s.offset == 5
+        assert s.indices == [1,0]
 
         #Now what happens if the array is transposed? strides[-1] != 1
         # therefore layout is non-contiguous
         assert backstrides == [2, 12]
         i = ArrayIter(MockArray, support.product(shape), shape,
                       strides, backstrides)
-        i.next()
-        i.next()
-        i.next()
-        assert i.offset == 9
-        assert not i.done()
-        assert i.indices == [0,3]
+        s = i.reset()
+        s = i.next(s)
+        s = i.next(s)
+        s = i.next(s)
+        assert s.offset == 9
+        assert not i.done(s)
+        assert s.indices == [0,3]
         #cause a dimension overflow
-        i.next()
-        i.next()
-        assert i.offset == 1
-        assert i.indices == [1,0]
+        s = i.next(s)
+        s = i.next(s)
+        assert s.offset == 1
+        assert s.indices == [1,0]
 
     def test_iterator_step(self):
         #iteration in C order with #contiguous layout => strides[-1] is 1
         assert backstrides == [10, 4]
         i = ArrayIter(MockArray, support.product(shape), shape,
                       strides, backstrides)
-        i.next_skip_x(2)
-        i.next_skip_x(2)
-        i.next_skip_x(2)
-        assert i.offset == 6
-        assert not i.done()
-        assert i.indices == [1,1]
+        s = i.reset()
+        s = i.next_skip_x(s, 2)
+        s = i.next_skip_x(s, 2)
+        s = i.next_skip_x(s, 2)
+        assert s.offset == 6
+        assert not i.done(s)
+        assert s.indices == [1,1]
         #And for some big skips
-        i.next_skip_x(5)
-        assert i.offset == 11
-        assert i.indices == [2,1]
-        i.next_skip_x(5)
+        s = i.next_skip_x(s, 5)
+        assert s.offset == 11
+        assert s.indices == [2,1]
+        s = i.next_skip_x(s, 5)
         # Note: the offset does not overflow but recycles,
         # this is good for broadcast
-        assert i.offset == 1
-        assert i.indices == [0,1]
-        assert i.done()
+        assert s.offset == 1
+        assert s.indices == [0,1]
+        assert i.done(s)
 
         #Now what happens if the array is transposed? strides[-1] != 1
         # therefore layout is non-contiguous
         assert backstrides == [2, 12]
         i = ArrayIter(MockArray, support.product(shape), shape,
                       strides, backstrides)
-        i.next_skip_x(2)
-        i.next_skip_x(2)
-        i.next_skip_x(2)
-        assert i.offset == 4
-        assert i.indices == [1,1]
-        assert not i.done()
-        i.next_skip_x(5)
-        assert i.offset == 5
-        assert i.indices == [2,1]
-        assert not i.done()
-        i.next_skip_x(5)
-        assert i.indices == [0,1]
-        assert i.offset == 3
-        assert i.done()
+        s = i.reset()
+        s = i.next_skip_x(s, 2)
+        s = i.next_skip_x(s, 2)
+        s = i.next_skip_x(s, 2)
+        assert s.offset == 4
+        assert s.indices == [1,1]
+        assert not i.done(s)
+        s = i.next_skip_x(s, 5)
+        assert s.offset == 5
+        assert s.indices == [2,1]
+        assert not i.done(s)
+        s = i.next_skip_x(s, 5)
+        assert s.indices == [0,1]
+        assert s.offset == 3
+        assert i.done(s)

pypy/module/micronumpy/test/test_zjit.py

                 raise Exception("need results")
             w_res = interp.results[-1]
             if isinstance(w_res, W_NDimArray):
-                w_res = w_res.create_iter().getitem()
+                i, s = w_res.create_iter()
+                w_res = i.getitem(s)
             if isinstance(w_res, boxes.W_Float64Box):
                 return w_res.value
             if isinstance(w_res, boxes.W_Int64Box):
         self.check_simple_loop({
             'float_add': 1,
             'getarrayitem_gc': 3,
-            'getfield_gc': 7,
             'guard_false': 1,
             'guard_not_invalidated': 1,
             'guard_true': 3,
             'raw_load': 2,
             'raw_store': 1,
             'setarrayitem_gc': 3,
-            'setfield_gc': 6,
         })
 
     def define_pow():
             'float_mul': 2,
             'float_ne': 1,
             'getarrayitem_gc': 3,
-            'getfield_gc': 7,
             'guard_false': 4,
             'guard_not_invalidated': 1,
             'guard_true': 5,
             'raw_load': 2,
             'raw_store': 1,
             'setarrayitem_gc': 3,
-            'setfield_gc': 6,
         })
 
     def define_sum():
         self.check_trace_count(1)
         self.check_simple_loop({
             'getarrayitem_gc': 2,
-            'getfield_gc': 4,
             'guard_not_invalidated': 1,
             'guard_true': 3,
             'int_add': 6,
             'raw_load': 1,
             'raw_store': 1,
             'setarrayitem_gc': 2,
-            'setfield_gc': 4,
         })
 
     def define_dot():
         return """
         a = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]
-        b=[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]
+        b = [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]
         c = dot(a, b)
         c -> 1 -> 2
         """
             'raw_load': 2,
         })
         self.check_resops({
+            'arraylen_gc': 1,
             'float_add': 2,
             'float_mul': 2,
             'getarrayitem_gc': 7,
             'getarrayitem_gc_pure': 15,
-            'getfield_gc': 35,
-            'getfield_gc_pure': 39,
+            'getfield_gc': 8,
+            'getfield_gc_pure': 44,
             'guard_class': 4,
             'guard_false': 14,
-            'guard_nonnull': 12,
-            'guard_nonnull_class': 4,
             'guard_not_invalidated': 2,
             'guard_true': 13,
-            'guard_value': 4,
             'int_add': 25,
             'int_ge': 4,
             'int_le': 8,
             'int_lt': 11,
             'int_sub': 4,
             'jump': 3,
+            'new_array': 1,
+            'new_with_vtable': 7,
             'raw_load': 6,
             'raw_store': 1,
-            'setarrayitem_gc': 10,
-            'setfield_gc': 14,
+            'same_as': 2,
+            'setarrayitem_gc': 8,
+            'setfield_gc': 16,
         })
 
     def define_argsort():
     def test_argsort(self):
         result = self.run("argsort")
         assert result == 6
+
+    def define_where():
+        return """
+        a = [1, 0, 1, 0]
+        x = [1, 2, 3, 4]
+        y = [-10, -20, -30, -40]
+        r = where(a, x, y)
+        r -> 3
+        """
+
+    def test_where(self):
+        result = self.run("where")
+        assert result == -40
+        self.check_trace_count(1)
+        self.check_simple_loop({
+            'float_ne': 1,
+            'getarrayitem_gc': 4,
+            'guard_false': 1,
+            'guard_not_invalidated': 1,
+            'guard_true': 5,
+            'int_add': 12,
+            'int_ge': 1,
+            'int_lt': 4,
+            'jump': 1,
+            'raw_load': 2,
+            'raw_store': 1,
+            'setarrayitem_gc': 4,
+        })