Daniel Roberts avatar Daniel Roberts committed d534f73

Added support code for array broadcasting, along with a small (unfinished) test.

Comments (0)

Files changed (2)

pypy/module/micronumpy/array.py

     else:
         return 0
 
+def broadcast_shapes(a_shape, a_strides, b_shape, b_strides):
+    a_dim = len(a_shape)
+    b_dim = len(b_shape)
+
+    smaller_dim = a_dim if a_dim < b_dim else b_dim
+
+    if a_dim > b_dim:
+        result = a_shape
+        larger_dim = a_dim
+        smaller_dim = b_dim
+        shorter_strides = b_strides
+    else:
+        result = b_shape
+        larger_dim = b_dim
+        smaller_dim = a_dim
+        shorter_strides = a_strides
+
+    i_a = a_dim - 1
+    i_b = b_dim - 1
+    for i in range(smaller_dim):
+        assert i_a >= 0
+        a = a_shape[i_a]
+
+        assert i_b >= 0
+        b = b_shape[i_b]
+
+        if a == b or a == 1 or b == 1:
+            i_a -= 1
+            i_b -= 1
+            result[len(result) - 1 - i] = a if a > b else b
+        else:
+            raise ValueError("frames are not aligned") # FIXME: applevel?
+    
+    if a_dim < b_dim:
+        i_b += 1
+        a_strides = [0] * i_b + a_strides
+    else:
+        i_a += 1
+        b_strides = [0] * i_a + b_strides
+    return result, a_strides, b_strides
+
 def normalize_slice_starts(slice_starts, shape):
     for i in range(len(slice_starts)):
         if slice_starts[i] < 0:

pypy/module/micronumpy/test/test_numpy.py

         for w_xs, typecode in data:
             assert typecode == infer_from_iterable(space, w_xs).typecode
 
+class TestArraySupport(object):
+    def test_broadcast_shapes(self, space):
+        from pypy.module.micronumpy.array import broadcast_shapes
+        from pypy.module.micronumpy.array import stride_row as stride
+
+        def test(shape_a, shape_b, expected_result, expected_strides_a=None, expected_strides_b=None):
+            strides_a = [stride(shape_a, i) for i, x in enumerate(shape_a)]
+            strides_a_save = strides_a[:]
+
+            strides_b = [stride(shape_b, i) for i, x in enumerate(shape_b)]
+            strides_b_save = strides_b[:]
+
+            result_shape, result_strides_a, result_strides_b = broadcast_shapes(shape_a, strides_a, shape_b, strides_b)
+            assert result_shape == expected_result
+
+            if expected_strides_a:
+                assert result_strides_a == expected_strides_a
+            else:
+                assert result_strides_a == strides_a_save
+
+            if expected_strides_b:
+                assert result_strides_b == expected_strides_b
+            else:
+                assert result_strides_b == strides_b_save
+
+        shape_a = [256, 256, 3]
+        shape_b = [3]
+
+        test([256, 256, 3], [3],
+             expected_result=[256, 256, 3],
+             expected_strides_b=[0, 0, 1])
+
+        test([3], [256, 256, 3],
+             expected_result=[256, 256, 3],
+             expected_strides_a=[0, 0, 1])
+
+        test([8, 1, 6, 1], [7, 1, 5],
+             expected_result=[8, 7, 6, 5],
+             expected_strides_b=[0, 5, 5, 1])
+
 class TestMicroArray(object):
     @py.test.mark.xfail # XXX: return types changed
     def test_index2strides(self, space):
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.