Ilya Osadchiy avatar Ilya Osadchiy committed fc54fc8

Initial (unoptimized) impementation of indexing by boolean vectors.

Comments (0)

Files changed (2)

pypy/module/micronumpy/interp_numarray.py

             bool_dtype = space.fromcache(interp_dtype.W_BoolDtype)
             int_dtype = space.fromcache(interp_dtype.W_Int64Dtype)
             if w_idx.find_dtype() is bool_dtype:
-                # TODO: indexing by bool array
-                raise NotImplementedError("sorry, not yet implemented")
+                # Indexing by boolean array
+                new_sig = signature.Signature.find_sig([
+                    IndexedByBoolArray.signature, self.signature
+                ])                
+                res = IndexedByBoolArray(new_sig, bool_dtype, self, w_idx)
+                return space.wrap(res)
             else:
                 # Indexing by array
 
         val = self.source.eval(idx).convert_to(self.res_dtype)
         return val
 
+class IndexedByBoolArray(VirtualArray):
+    """
+    Intermediate class for performing indexing of array by another array
+    """
+    # TODO: override "compute" to optimize (?)
+    signature = signature.BaseSignature()
+    def __init__(self, signature, bool_dtype, source, index):
+        VirtualArray.__init__(self, signature, source.find_dtype())
+        self.source = source
+        self.index = index
+        self.bool_dtype = bool_dtype
+        self.size = -1
+        self.cur_idx = 0
+
+    def _del_sources(self):
+        self.source = None
+        self.index = None
+
+    def _find_size(self):
+        # Finding size may be long, so we store the result for reuse.
+        if self.size != -1:
+            return self.size
+        # TODO: avoid index.get_concrete by using "sum" (reduce with "add")
+        idxs = self.index.get_concrete()
+        s = 0
+        i = 0
+        while i < self.index.find_size():
+            idx_val = self.bool_dtype.unbox(idxs.eval(i).convert_to(self.bool_dtype))
+            assert(isinstance(idx_val, bool))
+            if idx_val is True:
+                s += 1
+            i += 1
+        self.size = s
+        return self.size
+
+    def _eval(self, i):
+        if i == 0:
+            self.cur_idx = 0
+        while True:
+            idx_val = self.bool_dtype.unbox(self.index.eval(self.cur_idx).convert_to(self.bool_dtype))
+            assert(isinstance(idx_val, bool))
+            if idx_val is True:
+                break
+            self.cur_idx += 1
+        val = self.source.eval(self.cur_idx).convert_to(self.res_dtype)
+        self.cur_idx += 1
+        return val
+
 class ViewArray(BaseArray):
     """
     Class for representing views of arrays, they will reflect changes of parent

pypy/module/micronumpy/test/test_numarray.py

         for i in xrange(6):
             assert a_by_list[i] == range(5)[idx_list[i]]
 
+    def test_index_by_bool_array(self):
+        from numpy import array, dtype
+        a = array(range(5))
+        ind = array([False, True, False, True, False])
+        assert ind.dtype is dtype(bool)
+        # get length before actual calculation
+        b0 = a[ind]
+        assert len(b0) == 2
+        assert b0[0] == 1
+        assert b0[1] == 3
+        # get length after actual calculation
+        b1 = a[ind]
+        assert b1[0] == 1
+        assert b1[1] == 3
+        assert len(b1) == 2
+
     def test_setitem(self):
         from numpy import array
         a = array(range(5))
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.