Justin Peel avatar Justin Peel committed 67245c6

bool indexing requires a numpy array and matching size. use array's dtype if it is an ind for integer indexing.

Comments (0)

Files changed (2)

pypy/module/micronumpy/interp_numarray.py

                                      space.wrap("invalid index"))
             w_idx = space.getitem(w_idx, space.wrap(0))
         elif space.issequence_w(w_idx):
-            w_idx = convert_to_array(space, w_idx)
-            bool_dtype = space.fromcache(interp_dtype.W_BoolDtype)
-            int_dtype = space.fromcache(interp_dtype.W_Int64Dtype)
-            if w_idx.find_dtype() is bool_dtype:
-                # Indexing by boolean array
+            if isinstance(w_idx, BaseArray) and \
+                w_idx.find_dtype().kind == interp_dtype.BOOLLTR:
+                # Indexing by boolean array - must be using a bool numpy
+                # array of the exact same size as self to do this.
+                # Can't use a list.
+                bool_dtype = space.fromcache(interp_dtype.W_BoolDtype)
                 new_sig = signature.Signature.find_sig([
                     IndexedByBoolArray.signature, self.signature
-                ])                
+                ])
+                # will be more complicated with multi-dim arrays
+                if self.find_size() != w_idx.find_size():
+                    raise OperationError(space.w_ValueError,
+                        space.wrap("bool indexing requires matching dims"))
                 res = IndexedByBoolArray(new_sig, bool_dtype, self, w_idx)
                 return space.wrap(res)
             else:
+                w_idx = convert_to_array(space, w_idx)
                 # Indexing by array
-
                 # FIXME: should raise exception if any index in
                 # array is out od bound, but this kills lazy execution
                 new_sig = signature.Signature.find_sig([
                     IndexedByArray.signature, self.signature
-                ])                
-                res = IndexedByArray(new_sig, int_dtype, self, w_idx)
+                ])
+                # Use w_idx's dtype if possible
+                dtype = w_idx.find_dtype()
+                if dtype.kind != interp_dtype.SIGNEDLTR \
+                        and dtype.kind != interp_dtype.UNSIGNEDLTR:
+                    dtype = space.fromcache(interp_dtype.W_Int64Dtype)
+                res = IndexedByArray(new_sig, dtype, self, w_idx)
                 return space.wrap(res)
 
         start, stop, step, slice_length = space.decode_index4(w_idx, self.find_size())

pypy/module/micronumpy/test/test_numarray.py

     def test_index_by_bool_array(self):
         from numpy import array, dtype
         a = array(range(5))
+        ind = array([False,True])
+        raises(ValueError, "a[ind]")
         ind = array([False, True, False, True, False])
         assert ind.dtype is dtype(bool)
         # get length before actual calculation
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.