Commits

Maciej Fijalkowski  committed e614e15

start fancy indexing

  • Participants
  • Parent commits 63801ee
  • Branches numpy-fancy-indexing

Comments (0)

Files changed (6)

File pypy/module/micronumpy/arrayimpl/concrete.py

 
 from pypy.module.micronumpy.arrayimpl import base
 from pypy.module.micronumpy import support, loop
-from pypy.module.micronumpy.base import convert_to_array, W_NDimArray
+from pypy.module.micronumpy.base import convert_to_array, W_NDimArray,\
+     ArrayArgumentException
 from pypy.module.micronumpy.strides import calc_new_strides, shape_agreement,\
      calculate_broadcast_strides, calculate_dot_strides
 from pypy.module.micronumpy.iter import Chunk, Chunks, NewAxisChunk, RecordChunk
             item += idx * self.strides[i]
         return item
 
+    @jit.unroll_safe
+    def _lookup_by_unwrapped_index(self, space, lst):
+        item = self.start
+        for i, idx in enumerate(lst):
+            if idx < 0:
+                idx = self.shape[i] + idx
+            if idx < 0 or idx >= self.shape[i]:
+                raise operationerrfmt(space.w_IndexError,
+                      "index (%d) out of range (0<=index<%d", i, self.shape[i],
+                )
+            item += idx * self.strides[i]
+        return item
+
+    def getitem_index(self, space, index):
+        return self.getitem(self._lookup_by_unwrapped_index(space, index))
+
+    def setitem_index(self, space, index, value):
+        self.setitem(self._lookup_by_unwrapped_index(space, index), value)
+
     def _single_item_index(self, space, w_idx):
         """ Return an index of single item if possible, otherwise raises
         IndexError
             space.isinstance_w(w_idx, space.w_slice) or
             space.is_w(w_idx, space.w_None)):
             raise IndexError
+        if isinstance(w_idx, W_NDimArray):
+            raise ArrayArgumentException
         shape_len = len(self.shape)
         if shape_len == 0:
             raise OperationError(space.w_IndexError, space.wrap(
                 "0-d arrays can't be indexed"))
+        view_w = None
+        if (space.isinstance_w(w_idx, space.w_list) or
+            isinstance(w_idx, W_NDimArray)):
+            raise ArrayArgumentException
         if space.isinstance_w(w_idx, space.w_tuple):
             view_w = space.fixedview(w_idx)
             if len(view_w) < shape_len:
                 for w_item in view_w:
                     if space.is_w(w_item, space.w_None):
                         count -= 1
+                    if (space.isinstance_w(w_item, space.w_list) or
+                        isinstance(w_item, W_NDimArray)):
+                        raise ArrayArgumentException
                 if count == shape_len:
                     raise IndexError # but it's still not a single item
                 raise OperationError(space.w_IndexError,

File pypy/module/micronumpy/arrayimpl/scalar.py

         raise OperationError(space.w_IndexError,
                              space.wrap("scalars cannot be indexed"))
 
+    def getitem_index(self, space, idx):
+        raise OperationError(space.w_IndexError,
+                             space.wrap("scalars cannot be indexed"))
+
     def descr_setitem(self, space, w_idx, w_val):
         raise OperationError(space.w_IndexError,
                              space.wrap("scalars cannot be indexed"))
         
+    def setitem_index(self, space, idx, w_val):
+        raise OperationError(space.w_IndexError,
+                             space.wrap("scalars cannot be indexed"))
     def set_shape(self, space, new_shape):
         if not new_shape:
             return self

File pypy/module/micronumpy/base.py

 from pypy.tool.pairtype import extendabletype
 from pypy.module.micronumpy.support import calc_strides
 
+class ArrayArgumentException(Exception):
+    pass
+
 class W_NDimArray(Wrappable):
     __metaclass__ = extendabletype
 

File pypy/module/micronumpy/interp_numarray.py

 from pypy.interpreter.error import operationerrfmt, OperationError
 from pypy.interpreter.typedef import TypeDef, GetSetProperty
 from pypy.interpreter.gateway import interp2app, unwrap_spec
-from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
+from pypy.module.micronumpy.base import W_NDimArray, convert_to_array,\
+     ArrayArgumentException
 from pypy.module.micronumpy import interp_dtype, interp_ufuncs, interp_boxes
 from pypy.module.micronumpy.strides import find_shape_and_elems,\
      get_shape_from_iterable, to_coords
     def setitem_filter(self, space, idx, val):
         loop.setitem_filter(self, idx, val)
 
+    def _prepare_array_index(self, space, w_index):
+        if isinstance(w_index, W_NDimArray):
+            return w_index.get_shape(), [w_index]
+        w_lst = space.listview(w_index)
+        for w_item in w_lst:
+            if not space.isinstance_w(w_item, space.w_int):
+                break
+        else:
+            arr = convert_to_array(space, w_index)
+            return arr.get_shape(), [arr]
+        xxx
+
+    def getitem_array_int(self, space, w_index):
+        iter_shape, indexes = self._prepare_array_index(space, w_index)
+        shape = iter_shape + self.get_shape()[len(indexes):]
+        res = W_NDimArray.from_shape(shape, self.get_dtype(), self.get_order())
+        return loop.getitem_array_int(space, self, res, iter_shape, indexes)
+
     def descr_getitem(self, space, w_idx):
         if (isinstance(w_idx, W_NDimArray) and w_idx.get_shape() == self.get_shape() and
             w_idx.get_dtype().is_bool_type()):
             return self.getitem_filter(space, w_idx)
         try:
             return self.implementation.descr_getitem(space, w_idx)
+        except ArrayArgumentException:
+            return self.getitem_array_int(space, w_idx)
         except OperationError:
             raise OperationError(space.w_IndexError, space.wrap("wrong index"))
 
+    def getitem(self, space, index_list):
+        return self.implementation.getitem_index(space, index_list)
+
+    def setitem(self, space, index_list, w_value):
+        self.implementation.setitem_index(space, index_list, w_value)
+
     def descr_setitem(self, space, w_idx, w_value):
         if (isinstance(w_idx, W_NDimArray) and w_idx.get_shape() == self.get_shape() and
             w_idx.get_dtype().is_bool_type()):
             if self.is_scalar():
                 return self.get_scalar_value().item(space)
             if self.get_size() == 1:
-                w_obj = self.descr_getitem(space,
-                                           space.newtuple([space.wrap(0) for i
-                                      in range(len(self.get_shape()))]))
+                w_obj = self.getitem(space,
+                                     [0] * range(len(self.get_shape())))
                 assert isinstance(w_obj, interp_boxes.W_GenericBox)
                 return w_obj.item(space)
             raise OperationError(space.w_IndexError,
                 raise OperationError(space.w_IndexError,
                                      space.wrap("index out of bounds"))
             i = self.to_coords(space, w_arg)
-            item = self.descr_getitem(space, space.newtuple([space.wrap(x)
-                                             for x in i]))
+            item = self.getitem(space, i)
             assert isinstance(item, interp_boxes.W_GenericBox)
             return item.item(space)
         raise OperationError(space.w_NotImplementedError, space.wrap(

File pypy/module/micronumpy/loop.py

             builder.append(res_str_casted[i])
         iter.next()
     return builder.build()
+
+def getitem_array_int(space, arr, res, iter_shape, indexes):
+    assert len(indexes) == 1
+    assert len(iter_shape) == 1
+    res_iter = res.create_iter() # this shape is whatever shape res comes in
+    index_iter = indexes[0].create_iter()
+    while not index_iter.done():
+        idx = space.int_w(index_iter.getitem())
+        res_iter.setitem(arr.getitem(space, [idx]))
+        index_iter.next()
+        res_iter.next()
+    return res

File pypy/module/micronumpy/test/test_numarray.py

 
     def test_int_array_index(self):
         from numpypy import array, arange
-        assert (arange(10)[array([3, 2, 1, 5])] == [3, 2, 1, 5]).all()
+        b = arange(10)[array([3, 2, 1, 5])]
+        print b
+        assert (b == [3, 2, 1, 5]).all()
         raises(IndexError, "arange(10)[array([10])]")
         assert (arange(10)[[-5, -3]] == [5, 7]).all()
         raises(IndexError, "arange(10)[[-11]]")
-                        
+
+    def test_bool_array_index(self):
+        xxx
+
 class AppTestMultiDim(BaseNumpyAppTest):
     def test_init(self):
         import _numpypy