Commits

Maciej Fijalkowski  committed 971b8a2 Merge

Merge numpy-reintroduce-jit-drivers

This branch reintroduces jit drivers to numpy operations, however without
the lazy evaluation.

  • Participants
  • Parent commits 385b313, 44330b0

Comments (0)

Files changed (7)

File pypy/module/micronumpy/arrayimpl/concrete.py

     def setitem_index(self, space, index, value):
         self.setitem(self._lookup_by_unwrapped_index(space, index), value)
 
+    @jit.unroll_safe
     def _single_item_index(self, space, w_idx):
         """ Return an index of single item if possible, otherwise raises
         IndexError

File pypy/module/micronumpy/compile.py

 from pypy.interpreter.error import OperationError
 from pypy.module.micronumpy import interp_boxes
 from pypy.module.micronumpy.interp_dtype import get_dtype_cache
-from pypy.module.micronumpy.interp_numarray import (Scalar, BaseArray,
-     scalar_w, W_NDimArray, array)
+from pypy.module.micronumpy.base import W_NDimArray
+from pypy.module.micronumpy.interp_numarray import array
 from pypy.module.micronumpy.interp_arrayops import where
 from pypy.module.micronumpy import interp_ufuncs
 from pypy.rlib.objectmodel import specialize, instantiate
 
     def is_true(self, w_obj):
         assert isinstance(w_obj, BoolObject)
-        return w_obj.boolval
+        return False
+        #return w_obj.boolval
 
     def is_w(self, w_obj, w_what):
         return w_obj is w_what
         if isinstance(w_index, FloatObject):
             w_index = IntObject(int(w_index.floatval))
         w_val = self.expr.execute(interp)
-        assert isinstance(arr, BaseArray)
+        assert isinstance(arr, W_NDimArray)
         arr.descr_setitem(interp.space, w_index, w_val)
 
     def __repr__(self):
             w_rhs = self.rhs.wrap(interp.space)
         else:
             w_rhs = self.rhs.execute(interp)
-        if not isinstance(w_lhs, BaseArray):
+        if not isinstance(w_lhs, W_NDimArray):
             # scalar
             dtype = get_dtype_cache(interp.space).w_float64dtype
-            w_lhs = scalar_w(interp.space, dtype, w_lhs)
-        assert isinstance(w_lhs, BaseArray)
+            w_lhs = W_NDimArray.new_scalar(interp.space, dtype, w_lhs)
+        assert isinstance(w_lhs, W_NDimArray)
         if self.name == '+':
             w_res = w_lhs.descr_add(interp.space, w_rhs)
         elif self.name == '*':
         elif self.name == '-':
             w_res = w_lhs.descr_sub(interp.space, w_rhs)
         elif self.name == '->':
-            assert not isinstance(w_rhs, Scalar)
             if isinstance(w_rhs, FloatObject):
                 w_rhs = IntObject(int(w_rhs.floatval))
-            assert isinstance(w_lhs, BaseArray)
+            assert isinstance(w_lhs, W_NDimArray)
             w_res = w_lhs.descr_getitem(interp.space, w_rhs)
         else:
             raise NotImplementedError
-        if (not isinstance(w_res, BaseArray) and
+        if (not isinstance(w_res, W_NDimArray) and
             not isinstance(w_res, interp_boxes.W_GenericBox)):
             dtype = get_dtype_cache(interp.space).w_float64dtype
-            w_res = scalar_w(interp.space, dtype, w_res)
+            w_res = W_NDimArray.new_scalar(interp.space, dtype, w_res)
         return w_res
 
     def __repr__(self):
 
     def execute(self, interp):
         arr = self.args[0].execute(interp)
-        if not isinstance(arr, BaseArray):
+        if not isinstance(arr, W_NDimArray):
             raise ArgumentNotAnArray
         if self.name in SINGLE_ARG_FUNCTIONS:
             if len(self.args) != 1 and self.name != 'sum':
             elif self.name == "unegative":
                 neg = interp_ufuncs.get(interp.space).negative
                 w_res = neg.call(interp.space, [arr])
+            elif self.name == "cos":
+                cos = interp_ufuncs.get(interp.space).cos
+                w_res = cos.call(interp.space, [arr])                
             elif self.name == "flat":
                 w_res = arr.descr_get_flatiter(interp.space)
             elif self.name == "tostring":
                 arr.descr_tostring(interp.space)
                 w_res = None
-            elif self.name == "count_nonzero":
-                w_res = arr.descr_count_nonzero(interp.space)
             else:
                 assert False # unreachable code
         elif self.name in TWO_ARG_FUNCTIONS:
             if len(self.args) != 2:
                 raise ArgumentMismatch
             arg = self.args[1].execute(interp)
-            if not isinstance(arg, BaseArray):
+            if not isinstance(arg, W_NDimArray):
                 raise ArgumentNotAnArray
             if self.name == "dot":
                 w_res = arr.descr_dot(interp.space, arg)
                 raise ArgumentMismatch
             arg1 = self.args[1].execute(interp)
             arg2 = self.args[2].execute(interp)
-            if not isinstance(arg1, BaseArray):
+            if not isinstance(arg1, W_NDimArray):
                 raise ArgumentNotAnArray
-            if not isinstance(arg2, BaseArray):
+            if not isinstance(arg2, W_NDimArray):
                 raise ArgumentNotAnArray
             if self.name == "where":
                 w_res = where(interp.space, arr, arg1, arg2)
                 assert False
         else:
             raise WrongFunctionName
-        if isinstance(w_res, BaseArray):
+        if isinstance(w_res, W_NDimArray):
             return w_res
         if isinstance(w_res, FloatObject):
             dtype = get_dtype_cache(interp.space).w_float64dtype
             dtype = w_res.get_dtype(interp.space)
         else:
             dtype = None
-        return scalar_w(interp.space, dtype, w_res)
+        return W_NDimArray.new_scalar(interp.space, dtype, w_res)
 
 _REGEXES = [
     ('-?[\d\.]+', 'number'),

File pypy/module/micronumpy/interp_numarray.py

             if self.get_size() == 0:
                 raise OperationError(space.w_ValueError,
                     space.wrap("Can't call %s on zero-size arrays" % op_name))
-            return space.wrap(loop.argmin_argmax(op_name, self))
+            return space.wrap(getattr(loop, 'arg' + op_name)(self))
         return func_with_new_name(impl, "reduce_arg%s_impl" % op_name)
 
     descr_argmax = _reduce_argmax_argmin_impl("max")

File pypy/module/micronumpy/interp_ufuncs.py

             return out
         shape = shape_agreement(space, w_obj.get_shape(), out,
                                 broadcast_down=False)
-        return loop.call1(shape, self.func, self.name, calc_dtype, res_dtype,
+        return loop.call1(shape, self.func, calc_dtype, res_dtype,
                           w_obj, out)
 
 
             return out
         new_shape = shape_agreement(space, w_lhs.get_shape(), w_rhs)
         new_shape = shape_agreement(space, new_shape, out, broadcast_down=False)
-        return loop.call2(new_shape, self.func, self.name, calc_dtype,
+        return loop.call2(new_shape, self.func, calc_dtype,
                           res_dtype, w_lhs, w_rhs, out)
 
 

File pypy/module/micronumpy/loop.py

 
 """ This file is the main run loop as well as evaluation loops for various
-signatures
+operations. This is the place to look for all the computations that iterate
+over all the array elements.
 """
 
-from pypy.rlib.objectmodel import specialize
 from pypy.rlib.rstring import StringBuilder
 from pypy.rlib import jit
 from pypy.rpython.lltypesystem import lltype, rffi
 from pypy.module.micronumpy.base import W_NDimArray
 
-def call2(shape, func, name, calc_dtype, res_dtype, w_lhs, w_rhs, out):
+call2_driver = jit.JitDriver(name='numpy_call2',
+                             greens = ['shapelen', 'func', 'calc_dtype',
+                                       'res_dtype'],
+                             reds = ['shape', 'w_lhs', 'w_rhs', 'out',
+                                     'left_iter', 'right_iter', 'out_iter'])
+
+def call2(shape, func, calc_dtype, res_dtype, w_lhs, w_rhs, out):
     if out is None:
         out = W_NDimArray.from_shape(shape, res_dtype)
     left_iter = w_lhs.create_iter(shape)
     right_iter = w_rhs.create_iter(shape)
     out_iter = out.create_iter(shape)
+    shapelen = len(shape)
     while not out_iter.done():
+        call2_driver.jit_merge_point(shapelen=shapelen, func=func,
+                                     calc_dtype=calc_dtype, res_dtype=res_dtype,
+                                     shape=shape, w_lhs=w_lhs, w_rhs=w_rhs,
+                                     out=out,
+                                     left_iter=left_iter, right_iter=right_iter,
+                                     out_iter=out_iter)
         w_left = left_iter.getitem().convert_to(calc_dtype)
         w_right = right_iter.getitem().convert_to(calc_dtype)
         out_iter.setitem(func(calc_dtype, w_left, w_right).convert_to(
         out_iter.next()
     return out
 
-def call1(shape, func, name, calc_dtype, res_dtype, w_obj, out):
+call1_driver = jit.JitDriver(name='numpy_call1',
+                             greens = ['shapelen', 'func', 'calc_dtype',
+                                       'res_dtype'],
+                             reds = ['shape', 'w_obj', 'out', 'obj_iter',
+                                     'out_iter'])
+
+def call1(shape, func, calc_dtype, res_dtype, w_obj, out):
     if out is None:
         out = W_NDimArray.from_shape(shape, res_dtype)
     obj_iter = w_obj.create_iter(shape)
     out_iter = out.create_iter(shape)
+    shapelen = len(shape)
     while not out_iter.done():
+        call1_driver.jit_merge_point(shapelen=shapelen, func=func,
+                                     calc_dtype=calc_dtype, res_dtype=res_dtype,
+                                     shape=shape, w_obj=w_obj, out=out,
+                                     obj_iter=obj_iter, out_iter=out_iter)
         elem = obj_iter.getitem().convert_to(calc_dtype)
         out_iter.setitem(func(calc_dtype, elem).convert_to(res_dtype))
         out_iter.next()
         obj_iter.next()
     return out
 
+setslice_driver = jit.JitDriver(name='numpy_setslice',
+                                greens = ['shapelen', 'dtype'],
+                                reds = ['target', 'source', 'target_iter',
+                                        'source_iter'])
+
 def setslice(shape, target, source):
     # note that unlike everything else, target and source here are
     # array implementations, not arrays
     target_iter = target.create_iter(shape)
     source_iter = source.create_iter(shape)
     dtype = target.dtype
+    shapelen = len(shape)
     while not target_iter.done():
+        setslice_driver.jit_merge_point(shapelen=shapelen, dtype=dtype,
+                                        target=target, source=source,
+                                        target_iter=target_iter,
+                                        source_iter=source_iter)
         target_iter.setitem(source_iter.getitem().convert_to(dtype))
         target_iter.next()
         source_iter.next()
     return target
 
+reduce_driver = jit.JitDriver(name='numpy_reduce',
+                              greens = ['shapelen', 'func', 'done_func',
+                                        'calc_dtype', 'identity'],
+                              reds = ['obj', 'obj_iter', 'cur_value'])
+
 def compute_reduce(obj, calc_dtype, func, done_func, identity):
     obj_iter = obj.create_iter(obj.get_shape())
     if identity is None:
         obj_iter.next()
     else:
         cur_value = identity.convert_to(calc_dtype)
+    shapelen = len(obj.get_shape())
     while not obj_iter.done():
+        reduce_driver.jit_merge_point(shapelen=shapelen, func=func,
+                                      calc_dtype=calc_dtype, identity=identity,
+                                      done_func=done_func, obj=obj,
+                                      obj_iter=obj_iter, cur_value=cur_value)
         rval = obj_iter.getitem().convert_to(calc_dtype)
         if done_func is not None and done_func(calc_dtype, rval):
             return rval
         arr_iter.setitem(box)
         arr_iter.next()
 
+where_driver = jit.JitDriver(name='numpy_where',
+                             greens = ['shapelen', 'dtype', 'arr_dtype'],
+                             reds = ['shape', 'arr', 'x', 'y','arr_iter', 'out',
+                                     'x_iter', 'y_iter', 'iter', 'out_iter'])
+
 def where(out, shape, arr, x, y, dtype):
     out_iter = out.create_iter(shape)
     arr_iter = arr.create_iter(shape)
             iter = y_iter
     else:
         iter = x_iter
+    shapelen = len(shape)
     while not iter.done():
+        where_driver.jit_merge_point(shapelen=shapelen, shape=shape,
+                                     dtype=dtype, iter=iter, x_iter=x_iter,
+                                     y_iter=y_iter, arr_iter=arr_iter,
+                                     arr=arr, x=x, y=y, arr_dtype=arr_dtype,
+                                     out_iter=out_iter, out=out)
         w_cond = arr_iter.getitem()
         if arr_dtype.itemtype.bool(w_cond):
             w_val = x_iter.getitem().convert_to(dtype)
         y_iter.next()
     return out
 
+axis_reduce__driver = jit.JitDriver(name='numpy_axis_reduce',
+                                    greens=['shapelen', 'func', 'dtype',
+                                            'identity'],
+                                    reds=['axis', 'arr', 'out', 'shape',
+                                          'out_iter', 'arr_iter'])
+
 def do_axis_reduce(shape, func, arr, dtype, axis, out, identity):
     out_iter = out.create_axis_iter(arr.get_shape(), axis)
     arr_iter = arr.create_iter(arr.get_shape())
     if identity is not None:
         identity = identity.convert_to(dtype)
+    shapelen = len(shape)
     while not out_iter.done():
+        axis_reduce__driver.jit_merge_point(shapelen=shapelen, func=func,
+                                            dtype=dtype, identity=identity,
+                                            axis=axis, arr=arr, out=out,
+                                            shape=shape, out_iter=out_iter,
+                                            arr_iter=arr_iter)
         w_val = arr_iter.getitem().convert_to(dtype)
         if out_iter.first_line:
             if identity is not None:
         out_iter.next()
     return out
 
-@specialize.arg(0)
-def argmin_argmax(op_name, arr):
-    result = 0
-    idx = 1
-    dtype = arr.get_dtype()
-    iter = arr.create_iter(arr.get_shape())
-    cur_best = iter.getitem()
-    iter.next()
-    while not iter.done():
-        w_val = iter.getitem()
-        new_best = getattr(dtype.itemtype, op_name)(cur_best, w_val)
-        if dtype.itemtype.ne(new_best, cur_best):
-            result = idx
-            cur_best = new_best
+
+def _new_argmin_argmax(op_name):
+    arg_driver = jit.JitDriver(name='numpy_' + op_name,
+                               greens = ['shapelen', 'dtype'],
+                               reds = ['result', 'idx', 'cur_best', 'arr',
+                                       'iter'])
+    
+    def argmin_argmax(arr):
+        result = 0
+        idx = 1
+        dtype = arr.get_dtype()
+        iter = arr.create_iter(arr.get_shape())
+        cur_best = iter.getitem()
         iter.next()
-        idx += 1
-    return result
+        shapelen = len(arr.get_shape())
+        while not iter.done():
+            arg_driver.jit_merge_point(shapelen=shapelen, dtype=dtype,
+                                       result=result, idx=idx,
+                                       cur_best=cur_best, arr=arr, iter=iter)
+            w_val = iter.getitem()
+            new_best = getattr(dtype.itemtype, op_name)(cur_best, w_val)
+            if dtype.itemtype.ne(new_best, cur_best):
+                result = idx
+                cur_best = new_best
+            iter.next()
+            idx += 1
+        return result
+    return argmin_argmax
+argmin = _new_argmin_argmax('min')
+argmax = _new_argmin_argmax('max')
+
+# note that shapelen == 2 always
+dot_driver = jit.JitDriver(name = 'numpy_dot',
+                           greens = ['dtype'],
+                           reds = ['outi', 'lefti', 'righti', 'result'])
 
 def multidim_dot(space, left, right, result, dtype, right_critical_dim):
     ''' assumes left, right are concrete arrays
     lefti = left.create_dot_iter(broadcast_shape, left_skip)
     righti = right.create_dot_iter(broadcast_shape, right_skip)
     while not outi.done():
+        dot_driver.jit_merge_point(dtype=dtype, outi=outi, lefti=lefti,
+                                   righti=righti, result=result)
         lval = lefti.getitem().convert_to(dtype) 
         rval = righti.getitem().convert_to(dtype) 
         outval = outi.getitem().convert_to(dtype) 
         lefti.next()
     return result
 
+count_all_true_driver = jit.JitDriver(name = 'numpy_count',
+                                      greens = ['shapelen', 'dtype'],
+                                      reds = ['s', 'iter'])
+
 def count_all_true(arr):
     s = 0
     if arr.is_scalar():
         return arr.get_dtype().itemtype.bool(arr.get_scalar_value())
     iter = arr.create_iter()
+    shapelen = len(arr.get_shape())
+    dtype = arr.get_dtype()
     while not iter.done():
+        count_all_true_driver.jit_merge_point(shapelen=shapelen, iter=iter,
+                                              s=s, dtype=dtype)
         s += iter.getitem_bool()
         iter.next()
     return s
 
+getitem_filter_driver = jit.JitDriver(name = 'numpy_getitem_bool',
+                                      greens = ['shapelen', 'arr_dtype',
+                                                'index_dtype'],
+                                      reds = ['res', 'index_iter', 'res_iter',
+                                              'arr_iter'])
+
 def getitem_filter(res, arr, index):
     res_iter = res.create_iter()
     index_iter = index.create_iter()
     arr_iter = arr.create_iter()
+    shapelen = len(arr.get_shape())
+    arr_dtype = arr.get_dtype()
+    index_dtype = index.get_dtype()
+    # XXX length of shape of index as well?
     while not index_iter.done():
+        getitem_filter_driver.jit_merge_point(shapelen=shapelen,
+                                              index_dtype=index_dtype,
+                                              arr_dtype=arr_dtype,
+                                              res=res, index_iter=index_iter,
+                                              res_iter=res_iter,
+                                              arr_iter=arr_iter)
         if index_iter.getitem_bool():
             res_iter.setitem(arr_iter.getitem())
             res_iter.next()
         arr_iter.next()
     return res
 
+setitem_filter_driver = jit.JitDriver(name = 'numpy_setitem_bool',
+                                      greens = ['shapelen', 'arr_dtype',
+                                                'index_dtype'],
+                                      reds = ['index_iter', 'value_iter',
+                                              'arr_iter'])
+
 def setitem_filter(arr, index, value):
     arr_iter = arr.create_iter()
     index_iter = index.create_iter()
     value_iter = value.create_iter()
+    shapelen = len(arr.get_shape())
+    index_dtype = index.get_dtype()
+    arr_dtype = arr.get_dtype()
     while not index_iter.done():
+        setitem_filter_driver.jit_merge_point(shapelen=shapelen,
+                                              index_dtype=index_dtype,
+                                              arr_dtype=arr_dtype,
+                                              index_iter=index_iter,
+                                              value_iter=value_iter,
+                                              arr_iter=arr_iter)
         if index_iter.getitem_bool():
             arr_iter.setitem(value_iter.getitem())
             value_iter.next()
         arr_iter.next()
         index_iter.next()
 
+flatiter_getitem_driver = jit.JitDriver(name = 'numpy_flatiter_getitem',
+                                        greens = ['dtype'],
+                                        reds = ['step', 'ri', 'res',
+                                                'base_iter'])
+
 def flatiter_getitem(res, base_iter, step):
     ri = res.create_iter()
+    dtype = res.get_dtype()
     while not ri.done():
+        flatiter_getitem_driver.jit_merge_point(dtype=dtype,
+                                                base_iter=base_iter,
+                                                ri=ri, res=res, step=step)
         ri.setitem(base_iter.getitem())
         base_iter.next_skip_x(step)
         ri.next()
     return res
 
+flatiter_setitem_driver = jit.JitDriver(name = 'numpy_flatiter_setitem',
+                                        greens = ['dtype'],
+                                        reds = ['length', 'step', 'arr_iter',
+                                                'val_iter'])
+
 def flatiter_setitem(arr, val, start, step, length):
     dtype = arr.get_dtype()
     arr_iter = arr.create_iter()
     val_iter = val.create_iter()
     arr_iter.next_skip_x(start)
     while length > 0:
+        flatiter_setitem_driver.jit_merge_point(dtype=dtype, length=length,
+                                                step=step, arr_iter=arr_iter,
+                                                val_iter=val_iter)
         arr_iter.setitem(val_iter.getitem().convert_to(dtype))
         # need to repeat i_nput values until all assignments are done
         arr_iter.next_skip_x(step)
         # WTF numpy?
         val_iter.reset()
 
+fromstring_driver = jit.JitDriver(name = 'numpy_fromstring',
+                                  greens = ['itemsize', 'dtype'],
+                                  reds = ['i', 's', 'ai'])
+
 def fromstring_loop(a, dtype, itemsize, s):
     i = 0
     ai = a.create_iter()
     while not ai.done():
+        fromstring_driver.jit_merge_point(dtype=dtype, s=s, ai=ai, i=i,
+                                          itemsize=itemsize)
         val = dtype.itemtype.runpack_str(s[i*itemsize:i*itemsize + itemsize])
         ai.setitem(val)
         ai.next()
     def get_index(self, space):
         return [space.wrap(i) for i in self.indexes]
 
+getitem_int_driver = jit.JitDriver(name = 'numpy_getitem_int',
+                                   greens = ['shapelen', 'indexlen', 'dtype'],
+                                   reds = ['arr', 'res', 'iter', 'indexes_w',
+                                           'prefix_w'])
+
 def getitem_array_int(space, arr, res, iter_shape, indexes_w, prefix_w):
+    shapelen = len(iter_shape)
+    indexlen = len(indexes_w)
+    dtype = arr.get_dtype()
     iter = PureShapeIterator(iter_shape, indexes_w)
     while not iter.done():
+        getitem_int_driver.jit_merge_point(shapelen=shapelen, indexlen=indexlen,
+                                           dtype=dtype, arr=arr, res=res,
+                                           iter=iter, indexes_w=indexes_w,
+                                           prefix_w=prefix_w)
         # prepare the index
         index_w = [None] * len(indexes_w)
         for i in range(len(indexes_w)):
         iter.next()
     return res
 
+setitem_int_driver = jit.JitDriver(name = 'numpy_setitem_int',
+                                   greens = ['shapelen', 'indexlen', 'dtype'],
+                                   reds = ['arr', 'iter', 'indexes_w',
+                                           'prefix_w', 'val_arr'])
+
 def setitem_array_int(space, arr, iter_shape, indexes_w, val_arr,
                       prefix_w):
+    shapelen = len(iter_shape)
+    indexlen = len(indexes_w)
+    dtype = arr.get_dtype()
     iter = PureShapeIterator(iter_shape, indexes_w)
     while not iter.done():
+        setitem_int_driver.jit_merge_point(shapelen=shapelen, indexlen=indexlen,
+                                           dtype=dtype, arr=arr,
+                                           iter=iter, indexes_w=indexes_w,
+                                           prefix_w=prefix_w, val_arr=val_arr)
         # prepare the index
         index_w = [None] * len(indexes_w)
         for i in range(len(indexes_w)):

File pypy/module/micronumpy/test/test_compile.py

+
 import py
-py.test.skip("this is going away")
-
 from pypy.module.micronumpy.compile import (numpy_compile, Assignment,
     ArrayConstant, FloatConstant, Operator, Variable, RangeConstant, Execute,
     FunctionCall, FakeSpace)
         r
         """
         interp = self.run(code)
-        assert interp.results[0].value.value == 15
+        assert interp.results[0].get_scalar_value().value == 15
 
     def test_sum2(self):
         code = """
         sum(b)
         """
         interp = self.run(code)
-        assert interp.results[0].value.value == 30 * (30 - 1)
+        assert interp.results[0].get_scalar_value().value == 30 * (30 - 1)
 
 
     def test_array_write(self):
         b = a + a
         min(b)
         """)
-        assert interp.results[0].value.value == -24
+        assert interp.results[0].get_scalar_value().value == -24
 
     def test_max(self):
         interp = self.run("""
         b = a + a
         max(b)
         """)
-        assert interp.results[0].value.value == 256
+        assert interp.results[0].get_scalar_value().value == 256
 
     def test_slice(self):
         interp = self.run("""
         assert interp.results[0].value == 3
 
     def test_take(self):
+        py.test.skip("unsupported")
         interp = self.run("""
         a = |10|
         b = take(a, [1, 1, 3, 2])

File pypy/module/micronumpy/test/test_zjit.py

 """
 
 import py
-py.test.skip("this is going away")
-
 from pypy.jit.metainterp import pyjitpl
 from pypy.jit.metainterp.test.support import LLJitMixin
 from pypy.jit.metainterp.warmspot import reset_stats
 from pypy.module.micronumpy import interp_boxes
-from pypy.module.micronumpy.compile import (FakeSpace,
-    IntObject, Parser, InterpreterState)
-from pypy.module.micronumpy.interp_numarray import (W_NDimArray,
-     BaseArray, W_FlatIterator)
-from pypy.rlib.nonconst import NonConstant
-
+from pypy.module.micronumpy.compile import FakeSpace, Parser, InterpreterState
+from pypy.module.micronumpy.base import W_NDimArray
 
 class TestNumpyJIt(LLJitMixin):
     graph = None
             if not len(interp.results):
                 raise Exception("need results")
             w_res = interp.results[-1]
-            if isinstance(w_res, BaseArray):
-                concr = w_res.get_concrete_or_scalar()
-                sig = concr.find_sig()
-                frame = sig.create_frame(concr)
-                w_res = sig.eval(frame, concr)
+            if isinstance(w_res, W_NDimArray):
+                w_res = w_res.create_iter().getitem()
             if isinstance(w_res, interp_boxes.W_Float64Box):
                 return w_res.value
             if isinstance(w_res, interp_boxes.W_Int64Box):
             self.__class__.graph = graph
         reset_stats()
         pyjitpl._warmrunnerdesc.memory_manager.alive_loops.clear()
+        py.test.skip("don't run for now")
         return self.interp.eval_graph(self.graph, [i])
 
     def define_add():