Commits

Maciej Fijalkowski committed c5926e6

progress;

Comments (0)

Files changed (3)

pypy/module/micronumpy/interp_iter.py

     def __init__(self, arr):
         self.arr = arr.get_concrete()
 
+    def extend_shape(self, shape):
+        shape.extend(self.arr.shape)
+
 class BoolArrayChunk(BaseChunk):
     def __init__(self, arr):
         self.arr = arr.get_concrete()

pypy/module/micronumpy/interp_numarray.py

 from pypy.tool.sourcetools import func_with_new_name
 from pypy.rlib.rstring import StringBuilder
 from pypy.module.micronumpy.interp_iter import ArrayIterator, OneDimIterator,\
-     SkipLastAxisIterator, Chunk, ViewIterator
+     SkipLastAxisIterator, Chunk, ViewIterator, BoolArrayChunk, IntArrayChunk
 
 numpy_driver = jit.JitDriver(
     greens=['shapelen', 'sig'],
                 n_old_elems_to_use *= old_shape[oldI]
     return new_strides
 
+def wrap_chunk(space, w_idx, size):
+    if (space.isinstance_w(w_idx, space.w_int) or
+        space.isinstance_w(w_idx, space.w_slice)):
+        return Chunk(*space.decode_index4(w_idx, size))
+    arr = convert_to_array(space, w_idx)
+    if arr.find_dtype().is_bool_type():
+        return BoolArrayChunk(arr)
+    elif arr.find_dtype().is_int_type():
+        return IntArrayChunk(arr)
+    raise OperationError(space.w_IndexError, space.wrap("arrays used as indices must be of integer (or boolean) type"))
+
 class BaseArray(Wrappable):
     _attrs_ = ["invalidates", "shape", 'size']
 
         elif (space.isinstance_w(w_idx, space.w_slice) or
               space.isinstance_w(w_idx, space.w_int)):
             return False
+        if isinstance(w_idx, BaseArray):
+            return False
         lgt = space.len_w(w_idx)
         if lgt > shape_len:
             raise OperationError(space.w_IndexError,
         for w_item in space.fixedview(w_idx):
             if space.isinstance_w(w_item, space.w_slice):
                 return False
+            if isinstance(w_item, BaseArray):
+                return False
         return True
 
     @jit.unroll_safe
     def _prepare_slice_args(self, space, w_idx):
-        if (space.isinstance_w(w_idx, space.w_int) or
-            space.isinstance_w(w_idx, space.w_slice)):
-            return [Chunk(*space.decode_index4(w_idx, self.shape[0]))]
-        return [Chunk(*space.decode_index4(w_item, self.shape[i])) for i, w_item in
+        if not space.isinstance_w(w_idx, space.w_tuple):
+            return [wrap_chunk(space, w_idx, self.shape[0])]
+        return [wrap_chunk(space, w_item, self.shape[i]) for i, w_item in
                 enumerate(space.fixedview(w_idx))]
 
     def count_all_true(self, arr):
         if (isinstance(w_idx, BaseArray) and w_idx.shape == self.shape and
             w_idx.find_dtype().is_bool_type()):
             return self.getitem_filter(space, w_idx)
+        # XXX deal with a scalar
         if self._single_item_result(space, w_idx):
             concrete = self.get_concrete()
             item = concrete._index_of_single_item(space, w_idx)
         view = self.create_slice(chunks).get_concrete()
         view.setslice(space, w_value)
 
+    def force_slice(self, shape, chunks):
+        size = 1
+        for elem in shape:
+            size *= elem
+        res = W_NDimArray(size, shape, self.find_dtype())
+        xxx
+
     @jit.unroll_safe
     def create_slice(self, chunks):
         shape = []
         s = i + 1
         assert s >= 0
         shape += self.shape[s:]
+        for chunk in chunks:
+            if not isinstance(chunk, Chunk):
+                return self.force_slice(shape, chunks)
         if not isinstance(self, ConcreteArray):
             return VirtualSlice(self, chunks, shape)
         r = calculate_slice_strides(self.shape, self.start, self.strides,

pypy/module/micronumpy/test/test_numarray.py

         raises(TypeError, getattr, array(3), '__array_interface__')
 
     def test_array_indexing_one_elem(self):
-        skip("not yet")
         from _numpypy import array, arange
         raises(IndexError, 'arange(3)[array([3.5])]')
         a = arange(3)[array([1])]
-        assert a == 1
-        assert a[0] == 1
+        assert a == [1]
         raises(IndexError,'arange(3)[array([15])]')
         assert arange(3)[array([-3])] == 0
         raises(IndexError,'arange(3)[array([-15])]')
-        assert arange(3)[array(1)] == 1
 
     def test_array_indexing_bool(self):
         from _numpypy import arange