Matti Picus avatar Matti Picus committed 03aaf57

start to implement round(), need to add it to types.py

Comments (0)

Files changed (3)

pypy/module/micronumpy/interp_numarray.py

         raise OperationError(space.w_NotImplementedError, space.wrap(
             "resize not implemented yet"))
 
-    def descr_round(self, space, w_decimals=0, w_out=None):
-        raise OperationError(space.w_NotImplementedError, space.wrap(
-            "round not implemented yet"))
+    @unwrap_spec(decimals=int)
+    def descr_round(self, space, decimals=0, w_out=None):
+        if space.is_none(w_out):
+            w_out = None
+        elif not isinstance(w_out, W_NDimArray):
+            raise OperationError(space.w_TypeError, space.wrap(
+                "return arrays must be of ArrayType"))
+        out = interp_dtype.dtype_agreement(space, [self], self.get_shape(),
+                                           w_out)
+        loop.round(space, self, self.get_shape(), decimals, out)
+        return out
 
     def descr_searchsorted(self, space, w_v, w_side='left'):
         raise OperationError(space.w_NotImplementedError, space.wrap(
     byteswap = interp2app(W_NDimArray.descr_byteswap),
     choose   = interp2app(W_NDimArray.descr_choose),
     clip     = interp2app(W_NDimArray.descr_clip),
+    round    = interp2app(W_NDimArray.descr_round),
     data     = GetSetProperty(W_NDimArray.descr_get_data),
     diagonal = interp2app(W_NDimArray.descr_diagonal),
 

pypy/module/micronumpy/loop.py

 from pypy.module.micronumpy.iter import PureShapeIterator
 from pypy.module.micronumpy import constants
 from pypy.module.micronumpy.support import int_w
+from rpython.rlib.rfloat import round_double
 
 call2_driver = jit.JitDriver(name='numpy_call2',
                              greens = ['shapelen', 'func', 'calc_dtype',
         iter = x_iter
     shapelen = len(shape)
     while not iter.done():
-        where_driver.jit_merge_point(shapelen=shapelen, dtype=dtype, 
+        where_driver.jit_merge_point(shapelen=shapelen, dtype=dtype,
                                         arr_dtype=arr_dtype)
         w_cond = arr_iter.getitem()
         if arr_dtype.itemtype.bool(w_cond):
     return out
 
 axis_reduce__driver = jit.JitDriver(name='numpy_axis_reduce',
-                                    greens=['shapelen', 
+                                    greens=['shapelen',
                                             'func', 'dtype',
                                             'identity'],
                                     reds='auto')
     arg_driver = jit.JitDriver(name='numpy_' + op_name,
                                greens = ['shapelen', 'dtype'],
                                reds = 'auto')
-    
+
     def argmin_argmax(arr):
         result = 0
         idx = 1
      result.shape == [3, 5, 2, 4]
      broadcast shape should be [3, 5, 2, 7, 4]
      result should skip dims 3 which is len(result_shape) - 1
-        (note that if right is 1d, result should 
+        (note that if right is 1d, result should
                   skip len(result_shape))
      left should skip 2, 4 which is a.ndims-1 + range(right.ndims)
           except where it==(right.ndims-2)
     righti = right.create_dot_iter(broadcast_shape, right_skip)
     while not outi.done():
         dot_driver.jit_merge_point(dtype=dtype)
-        lval = lefti.getitem().convert_to(dtype) 
-        rval = righti.getitem().convert_to(dtype) 
-        outval = outi.getitem().convert_to(dtype) 
+        lval = lefti.getitem().convert_to(dtype)
+        rval = righti.getitem().convert_to(dtype)
+        outval = outi.getitem().convert_to(dtype)
         v = dtype.itemtype.mul(lval, rval)
         value = dtype.itemtype.add(v, outval).convert_to(dtype)
         outi.setitem(value)
         setitem_filter_driver.jit_merge_point(shapelen=shapelen,
                                               index_dtype=index_dtype,
                                               arr_dtype=arr_dtype,
-                                             ) 
+                                             )
         if index_iter.getitem_bool():
             arr_iter.setitem(value_iter.getitem())
             value_iter.next()
         out_iter.next()
         min_iter.next()
 
+round_driver = jit.JitDriver(greens = ['shapelen', 'dtype'],
+                                    reds = 'auto')
+
+def round(space, arr, shape, decimals, out):
+    arr_iter = arr.create_iter(shape)
+    dtype = out.get_dtype()
+    shapelen = len(shape)
+    out_iter = out.create_iter(shape)
+    while not arr_iter.done():
+        round_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
+        #w_v = dtype.itemtype.round(arr_iter.getitem().convert_to(dtype),
+        #             decimals)
+        w_v = arr_iter.getitem().convert_to(dtype)
+        out_iter.setitem(w_v)
+        arr_iter.next()
+        out_iter.next()
+
 diagonal_simple_driver = jit.JitDriver(greens = ['axis1', 'axis2'],
                                        reds = 'auto')
 
         out_iter.setitem(arr.getitem_index(space, indexes))
         iter.next()
         out_iter.next()
-       
+

pypy/module/micronumpy/test/test_ufuncs.py

         assert all([math.copysign(1, f(abs(float("nan")))) == 1 for f in floor, ceil, trunc])
         assert all([math.copysign(1, f(-abs(float("nan")))) == -1 for f in floor, ceil, trunc])
 
-    def test_around(self):
+    def test_round(self):
         from numpypy import array
         ninf, inf = float("-inf"), float("inf")
         a = array([ninf, -1.4, -1.5, -1.0, 0.0, 1.0, 1.4, 0.5, inf])
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.