Commits

Brian Kearns committed f5dbd42

fix indexing using array scalars

Comments (0)

Files changed (3)

pypy/module/micronumpy/arrayimpl/concrete.py

                     "field named %s not found" % idx))
             return RecordChunk(idx)
         if (space.isinstance_w(w_idx, space.w_int) or
-            space.isinstance_w(w_idx, space.w_slice)):
+                space.isinstance_w(w_idx, space.w_slice)):
+            return Chunks([Chunk(*space.decode_index4(w_idx, self.get_shape()[0]))])
+        elif isinstance(w_idx, W_NDimArray) and \
+                isinstance(w_idx.implementation, scalar.Scalar):
+            w_idx = w_idx.get_scalar_value().item(space)
+            if not space.isinstance_w(w_idx, space.w_int) and \
+                    not space.isinstance_w(w_idx, space.w_bool):
+                raise OperationError(space.w_IndexError, space.wrap(
+                    "arrays used as indices must be of integer (or boolean) type"))
             return Chunks([Chunk(*space.decode_index4(w_idx, self.get_shape()[0]))])
         elif space.is_w(w_idx, space.w_None):
             return Chunks([NewAxisChunk()])

pypy/module/micronumpy/interp_numarray.py

                                prefix)
 
     def descr_getitem(self, space, w_idx):
-        if isinstance(w_idx, W_NDimArray) and w_idx.get_dtype().is_bool_type():
+        if isinstance(w_idx, W_NDimArray) and w_idx.get_dtype().is_bool_type() \
+                and len(w_idx.get_shape()) > 0:
             return self.getitem_filter(space, w_idx)
         try:
             return self.implementation.descr_getitem(space, self, w_idx)

pypy/module/micronumpy/test/test_numarray.py

         raises(IndexError, "arange(10)[array([10])] = 3")
         raises(IndexError, "arange(10)[[-11]] = 3")
 
-    def test_bool_single_index(self):
+    def test_array_scalar_index(self):
         import numpypy as np
         a = np.array([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]])
-        a[np.array(True)]; skip("broken")  # check for crash but skip rest of test until correct
+        assert (a[np.array(0)] == a[0]).all()
+        assert (a[np.array(1)] == a[1]).all()
+        exc = raises(IndexError, "a[np.array(1.1)]")
+        assert exc.value.message == 'arrays used as indices must be of ' \
+                                    'integer (or boolean) type'
         assert (a[np.array(True)] == a[1]).all()
         assert (a[np.array(False)] == a[0]).all()