Commits

Maciej Fijalkowski committed 878c3e6

the simplest setitem

Comments (0)

Files changed (3)

pypy/module/micronumpy/interp_numarray.py

         res = W_NDimArray.from_shape(shape, self.get_dtype(), self.get_order())
         return loop.getitem_array_int(space, self, res, iter_shape, indexes)
 
+    def setitem_array_int(self, space, w_index, w_value):
+        val_arr = convert_to_array(space, w_value)
+        iter_shape, indexes = self._prepare_array_index(space, w_index)
+        return loop.setitem_array_int(space, self, iter_shape, indexes, val_arr)
+
     def descr_getitem(self, space, w_idx):
         if (isinstance(w_idx, W_NDimArray) and
             w_idx.get_dtype().is_bool_type()):
             w_idx.get_dtype().is_bool_type()):
             return self.setitem_filter(space, w_idx,
                                        convert_to_array(space, w_value))
-        self.implementation.descr_setitem(space, w_idx, w_value)
+        try:
+            self.implementation.descr_setitem(space, w_idx, w_value)
+        except ArrayArgumentException:
+            self.setitem_array_int(space, w_idx, w_value)
 
     def descr_len(self, space):
         shape = self.get_shape()

pypy/module/micronumpy/loop.py

         index_iter.next()
         res_iter.next()
     return res
+
+def setitem_array_int(space, arr, iter_shape, indexes, val_arr):
+    assert len(indexes) == 1
+    assert len(iter_shape) == 1
+    index_iter = indexes[0].create_iter()
+    dtype = arr.get_dtype()
+    val_iter = val_arr.create_iter(iter_shape)
+    while not index_iter.done():
+        idx = space.int_w(index_iter.getitem())
+        arr.setitem(space, [idx], val_iter.getitem().convert_to(dtype))
+        val_iter.next()
+        index_iter.next()

pypy/module/micronumpy/test/test_numarray.py

     def test_int_array_index(self):
         from numpypy import array, arange
         b = arange(10)[array([3, 2, 1, 5])]
-        print b
         assert (b == [3, 2, 1, 5]).all()
         raises(IndexError, "arange(10)[array([10])]")
         assert (arange(10)[[-5, -3]] == [5, 7]).all()
         raises(IndexError, "arange(10)[[-11]]")
 
+    def test_int_array_index_setitem(self):
+        from numpypy import array, arange, zeros
+        a = arange(10)
+        a[[3, 2, 1, 5]] = zeros(4, dtype=int)
+        assert (a == [0, 0, 0, 0, 4, 0, 6, 7, 8, 9]).all()
+        a[[-9, -8]] = [1, 1]
+        assert (a == [0, 1, 1, 0, 4, 0, 6, 7, 8, 9]).all()
+        raises(IndexError, "arange(10)[array([10])] = 3")
+        raises(IndexError, "arange(10)[[-11]] = 3")
+
     def test_bool_array_index(self):
         from numpypy import arange, array
         b = arange(10)
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.