1. Pypy
  2. Untitled project
  3. pypy

Commits

Matti Picus  committed 7f18238

add broadcast_backwards to allow indexing by a lower-shaped array

  • Participants
  • Parent commits 47a587e
  • Branches indexing-by-array

Comments (0)

Files changed (5)

File pypy/module/micronumpy/arrayimpl/concrete.py

View file
  • Ignore whitespace
         self.backstrides = backstrides
         self.storage = storage
 
-    def create_iter(self, shape=None):
+    def create_iter(self, shape=None, backward_broadcast=False):
         if shape is None or shape == self.get_shape():
             return iter.ConcreteArrayIterator(self)
         r = calculate_broadcast_strides(self.get_strides(),
                                         self.get_backstrides(),
-                                        self.get_shape(), shape)
+                                        self.get_shape(), shape, backward_broadcast)
         return iter.MultiDimViewIterator(self, self.dtype, 0, r[0], r[1], shape)
 
     def fill(self, box):
     def fill(self, box):
         loop.fill(self, box.convert_to(self.dtype))
 
-    def create_iter(self, shape=None):
+    def create_iter(self, shape=None, backward_broadcast=False):
         if shape is not None and shape != self.get_shape():
             r = calculate_broadcast_strides(self.get_strides(),
                                             self.get_backstrides(),
-                                            self.get_shape(), shape)
+                                            self.get_shape(), shape,
+                                            backward_broadcast)
             return iter.MultiDimViewIterator(self.parent, self.dtype,
                                              self.start, r[0], r[1], shape)
         if len(self.get_shape()) == 1:

File pypy/module/micronumpy/interp_numarray.py

View file
  • Ignore whitespace
         s.append('])')
         return s.build()
 
-    def create_iter(self, shape=None):
+    def create_iter(self, shape=None, backward_broadcast=False):
         assert isinstance(self.implementation, BaseArrayImplementation)
-        return self.implementation.create_iter(shape)
+        return self.implementation.create_iter(shape,
+                                   backward_broadcast=backward_broadcast)
 
     def create_axis_iter(self, shape, dim, cum):
         return self.implementation.create_axis_iter(shape, dim, cum)

File pypy/module/micronumpy/loop.py

View file
  • Ignore whitespace
 
 def getitem_filter(res, arr, index):
     res_iter = res.create_iter()
-    index_iter = index.create_iter(arr.get_shape())
+    index_iter = index.create_iter(arr.get_shape(), backward_broadcast=True)
     arr_iter = arr.create_iter()
     shapelen = len(arr.get_shape())
     arr_dtype = arr.get_dtype()
     index_dtype = index.get_dtype()
     # XXX length of shape of index as well?
     while not index_iter.done():
-        print 'res,arr,index', res_iter.offset, arr_iter.offset, index_iter.offset, index_iter.getitem_bool()
         getitem_filter_driver.jit_merge_point(shapelen=shapelen,
                                               index_dtype=index_dtype,
                                               arr_dtype=arr_dtype,

File pypy/module/micronumpy/strides.py

View file
  • Ignore whitespace
     rshape += shape[s:]
     return rshape, rstart, rstrides, rbackstrides
 
-def calculate_broadcast_strides(strides, backstrides, orig_shape, res_shape):
+def calculate_broadcast_strides(strides, backstrides, orig_shape, res_shape, backwards=False):
     rstrides = []
     rbackstrides = []
     for i in range(len(orig_shape)):
         else:
             rstrides.append(strides[i])
             rbackstrides.append(backstrides[i])
-    rstrides = [0] * (len(res_shape) - len(orig_shape)) + rstrides
-    rbackstrides = [0] * (len(res_shape) - len(orig_shape)) + rbackstrides
+    if backwards:
+        rstrides = rstrides + [0] * (len(res_shape) - len(orig_shape))  
+        rbackstrides = rbackstrides + [0] * (len(res_shape) - len(orig_shape)) 
+    else:
+        rstrides = [0] * (len(res_shape) - len(orig_shape)) + rstrides
+        rbackstrides = [0] * (len(res_shape) - len(orig_shape)) + rbackstrides
     return rstrides, rbackstrides
 
 def is_single_elem(space, w_elem, is_rec_type):

File pypy/module/micronumpy/test/test_numarray.py

View file
  • Ignore whitespace
         raises(ValueError, "b[array([[True, False], [True, False]])]")
         a = array([[1,2,3],[4,5,6],[7,8,9]],int)
         c = array([True,False,True],bool)
-        print 'xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx'
         b = a[c]
-        print 'yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy'
         assert (a[c] == [[1, 2, 3], [7, 8, 9]]).all()
 
     def test_bool_array_index_setitem(self):