Maciej Fijalkowski avatar Maciej Fijalkowski committed c0b1fcc

start adding jitdrivers

Comments (0)

Files changed (2)

pypy/module/micronumpy/interp_ufuncs.py

             return out
         shape = shape_agreement(space, w_obj.get_shape(), out,
                                 broadcast_down=False)
-        return loop.call1(shape, self.func, self.name, calc_dtype, res_dtype,
+        return loop.call1(shape, self.func, calc_dtype, res_dtype,
                           w_obj, out)
 
 
             return out
         new_shape = shape_agreement(space, w_lhs.get_shape(), w_rhs)
         new_shape = shape_agreement(space, new_shape, out, broadcast_down=False)
-        return loop.call2(new_shape, self.func, self.name, calc_dtype,
+        return loop.call2(new_shape, self.func, calc_dtype,
                           res_dtype, w_lhs, w_rhs, out)
 
 

pypy/module/micronumpy/loop.py

 
 """ This file is the main run loop as well as evaluation loops for various
-signatures
+operations. This is the place to look for all the computations that iterate
+over all the array elements.
 """
 
 from pypy.rlib.objectmodel import specialize
 from pypy.rpython.lltypesystem import lltype, rffi
 from pypy.module.micronumpy.base import W_NDimArray
 
-def call2(shape, func, name, calc_dtype, res_dtype, w_lhs, w_rhs, out):
+call2_driver = jit.JitDriver(name='numpy_call2',
+                             greens = ['shapelen', 'func', 'calc_dtype',
+                                       'res_dtype'],
+                             reds = ['shape', 'w_lhs', 'w_rhs', 'out',
+                                     'left_iter', 'right_iter', 'out_iter'])
+
+def call2(shape, func, calc_dtype, res_dtype, w_lhs, w_rhs, out):
     if out is None:
         out = W_NDimArray.from_shape(shape, res_dtype)
     left_iter = w_lhs.create_iter(shape)
     right_iter = w_rhs.create_iter(shape)
     out_iter = out.create_iter(shape)
+    shapelen = len(shape)
     while not out_iter.done():
+        call2_driver.jit_merge_point(shapelen=shapelen, func=func,
+                                     calc_dtype=calc_dtype, res_dtype=res_dtype,
+                                     shape=shape, w_lhs=w_lhs, w_rhs=w_rhs,
+                                     out=out,
+                                     left_iter=left_iter, right_iter=right_iter,
+                                     out_iter=out_iter)
         w_left = left_iter.getitem().convert_to(calc_dtype)
         w_right = right_iter.getitem().convert_to(calc_dtype)
         out_iter.setitem(func(calc_dtype, w_left, w_right).convert_to(
         out_iter.next()
     return out
 
-def call1(shape, func, name, calc_dtype, res_dtype, w_obj, out):
+call1_driver = jit.JitDriver(name='numpy_call1',
+                             greens = ['shapelen', 'func', 'calc_dtype',
+                                       'res_dtype'],
+                             reds = ['shape', 'w_obj', 'out', 'obj_iter',
+                                     'out_iter'])
+
+def call1(shape, func, calc_dtype, res_dtype, w_obj, out):
     if out is None:
         out = W_NDimArray.from_shape(shape, res_dtype)
     obj_iter = w_obj.create_iter(shape)
     out_iter = out.create_iter(shape)
+    shapelen = len(shape)
     while not out_iter.done():
+        call1_driver.jit_merge_point(shapelen=shapelen, func=func,
+                                     calc_dtype=calc_dtype, res_dtype=res_dtype,
+                                     shape=shape, w_obj=w_obj, out=out,
+                                     obj_iter=obj_iter, out_iter=out_iter)
         elem = obj_iter.getitem().convert_to(calc_dtype)
         out_iter.setitem(func(calc_dtype, elem).convert_to(res_dtype))
         out_iter.next()
         obj_iter.next()
     return out
 
+setslice_driver = jit.JitDriver(name='numpy_setslice',
+                                greens = ['shapelen', 'dtype'],
+                                reds = ['target', 'source', 'target_iter',
+                                        'source_iter'])
+
 def setslice(shape, target, source):
     # note that unlike everything else, target and source here are
     # array implementations, not arrays
     target_iter = target.create_iter(shape)
     source_iter = source.create_iter(shape)
     dtype = target.dtype
+    shapelen = len(shape)
     while not target_iter.done():
+        setslice_driver.jit_merge_point(shapelen=shapelen, dtype=dtype,
+                                        target=target, source=source,
+                                        target_iter=target_iter,
+                                        source_iter=source_iter)
         target_iter.setitem(source_iter.getitem().convert_to(dtype))
         target_iter.next()
         source_iter.next()
     return target
 
+reduce_driver = jit.JitDriver(name='numpy_reduce',
+                              greens = ['shapelen', 'func', 'calc_dtype',
+                                        'identity', 'done_func'],
+                              reds = ['obj', 'obj_iter', 'cur_value'])
+
 def compute_reduce(obj, calc_dtype, func, done_func, identity):
     obj_iter = obj.create_iter(obj.get_shape())
     if identity is None:
         obj_iter.next()
     else:
         cur_value = identity.convert_to(calc_dtype)
+    shapelen = len(obj.get_shape())
     while not obj_iter.done():
+        reduce_driver.jit_merge_point(shapelen=shapelen, func=func,
+                                      calc_dtype=calc_dtype, identity=identity,
+                                      done_func=done_func, obj=obj,
+                                      obj_iter=obj_iter, cur_value=cur_value)
         rval = obj_iter.getitem().convert_to(calc_dtype)
         if done_func is not None and done_func(calc_dtype, rval):
             return rval
         arr_iter.setitem(box)
         arr_iter.next()
 
+#where_driver = jit.JitDriver()
+
 def where(out, shape, arr, x, y, dtype):
     out_iter = out.create_iter(shape)
     arr_iter = arr.create_iter(shape)
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.