Brian Kearns avatar Brian Kearns committed 912fe41

fix promote_to_largest in reduce operations (fixes issue1663)

Comments (0)

Files changed (3)

pypy/module/micronumpy/interp_numarray.py

         return func_with_new_name(impl, "reduce_%s_impl_%d_%d" % (ufunc_name,
                     promote_to_largest, cumulative))
 
-    descr_sum = _reduce_ufunc_impl("add")
-    descr_sum_promote = _reduce_ufunc_impl("add", True)
+    descr_sum = _reduce_ufunc_impl("add", True)
     descr_prod = _reduce_ufunc_impl("multiply", True)
     descr_max = _reduce_ufunc_impl("maximum")
     descr_min = _reduce_ufunc_impl("minimum")

pypy/module/micronumpy/interp_ufuncs.py

 @jit.unroll_safe
 def find_unaryop_result_dtype(space, dt, promote_to_float=False,
         promote_bools=False, promote_to_largest=False):
+    if promote_to_largest:
+        if dt.kind == NPY_GENBOOLLTR or dt.kind == NPY_SIGNEDLTR:
+            return interp_dtype.get_dtype_cache(space).w_int64dtype
+        elif dt.kind == NPY_UNSIGNEDLTR:
+            return interp_dtype.get_dtype_cache(space).w_uint64dtype
+        elif dt.kind == NPY_FLOATINGLTR or dt.kind == NPY_COMPLEXLTR:
+            return dt
+        else:
+            assert False
     if promote_bools and (dt.kind == NPY_GENBOOLLTR):
         return interp_dtype.get_dtype_cache(space).w_int8dtype
     if promote_to_float:
             if (dtype.kind == NPY_FLOATINGLTR and
                 dtype.itemtype.get_element_size() > dt.itemtype.get_element_size()):
                 return dtype
-    if promote_to_largest:
-        if dt.kind == NPY_GENBOOLLTR or dt.kind == NPY_SIGNEDLTR:
-            return interp_dtype.get_dtype_cache(space).w_float64dtype
-        elif dt.kind == NPY_FLOATINGLTR:
-            return interp_dtype.get_dtype_cache(space).w_float64dtype
-        elif dt.kind == NPY_UNSIGNEDLTR:
-            return interp_dtype.get_dtype_cache(space).w_uint64dtype
-        else:
-            assert False
     return dt
 
 def find_dtype_for_scalar(space, w_obj, current_guess=None):

pypy/module/micronumpy/test/test_numarray.py

         assert d[1] == 12
 
     def test_sum(self):
-        from numpypy import array, zeros
+        from numpypy import array, zeros, float16, complex64, str_
         a = array(range(5))
         assert a.sum() == 10
         assert a[:4].sum() == 6
         a = array([True] * 5, bool)
         assert a.sum() == 5
 
+        assert array([True, False] * 200).sum() == 200
+        assert array([True, False] * 200, dtype='int8').sum() == 200
+        assert array([True, False] * 200).sum(dtype='int8') == -56
+        assert type(array([True, False] * 200, dtype='float16').sum()) is float16
+        assert type(array([True, False] * 200, dtype='complex64').sum()) is complex64
+
         raises(TypeError, 'a.sum(axis=0, out=3)')
         raises(ValueError, 'a.sum(axis=2)')
         d = array(0.)
         assert (array([[1,2],[3,4]]).prod(1) == [2, 12]).all()
 
     def test_prod(self):
-        from numpypy import array
+        from numpypy import array, int_, dtype
         a = array(range(1, 6))
         assert a.prod() == 120.0
         assert a[:4].prod() == 24.0
+        a = array([True, False])
+        assert a.prod() == 0
+        assert type(a.prod()) is int_
+        a = array([True, False], dtype='uint')
+        assert a.prod() == 0
+        assert type(a.prod()) is dtype('uint').type
 
     def test_max(self):
         from numpypy import array, zeros
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.