Ilya Osadchiy avatar Ilya Osadchiy committed 949a502

Check size of arrays in binary operations (needed for comparison)

Comments (0)

Files changed (3)

pypy/module/micronumpy/interp_ufuncs.py

                 w_rhs.value.convert_to(calc_dtype)
             ).wrap(space)
 
+        w_lhs, w_rhs = _broadcast_arrays(space, w_lhs, w_rhs)
         new_sig = signature.Signature.find_sig([
             self.signature, w_lhs.signature, w_rhs.signature
         ])
     reduce = interp2app(W_Ufunc.descr_reduce),
 )
 
+def _broadcast_arrays(space, a1, a2):
+    from pypy.module.micronumpy.interp_numarray import Scalar
+    '''
+    Broadcast arrays to common size
+    '''
+    # For now just check sizes of two 1D arrays
+    if isinstance(a1, Scalar) or isinstance(a2, Scalar):
+        return a1, a2
+    s1 = a1.find_size()
+    s2 = a2.find_size()
+    if s1 != s2:
+        raise operationerrfmt(space.w_ValueError, "operands could not "
+            "be broadcast together with shapes (%d) (%d)", s1, s2)
+    return a1, a2
+
 def find_binop_result_dtype(space, dt1, dt2, promote_to_float=False,
     promote_bools=False):
     # dt1.num should be <= dt2.num

pypy/module/micronumpy/test/test_base.py

 
     def test_slice_signature(self, space):
         ar = SingleDimArray(10, dtype=space.fromcache(interp_dtype.W_Float64Dtype))
-        v1 = ar.descr_getitem(space, space.wrap(slice(1, 5, 1)))
+        v1 = ar.descr_getitem(space, space.wrap(slice(1, 3, 1)))
         v2 = ar.descr_getitem(space, space.wrap(slice(4, 6, 1)))
+        v3 = ar.descr_getitem(space, space.wrap(slice(3, 5, 1)))
         assert v1.signature is v2.signature
 
-        v3 = ar.descr_add(space, v1)
-        v4 = ar.descr_add(space, v2)
-        assert v3.signature is v4.signature
+        v4 = v1.descr_add(space, v2)
+        v5 = v1.descr_add(space, v3)
+        assert v4.signature is v5.signature
 
 class TestUfuncCoerscion(object):
     def test_binops(self, space):

pypy/module/micronumpy/test/test_numarray.py

         a = array(range(5))
         assert (a == None) is False
         assert (a != None) is True
-        # TODO: uncomment after size check is implemented
-        # b = array(range(2))
-        # assert (a == b) is False
-        # assert (a != b) is True
+        b = array(range(2))
+        assert (a == b) is False
+        assert (a != b) is True
 
 class AppTestSupport(object):
     def setup_class(cls):
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.