Commits

Pypy  committed 2d41ac5 Merge

Merged in MichaelBlume/pypy/newindex (pull request #59)

  • Participants
  • Parent commits 3800be5, 60deed4

Comments (0)

Files changed (5)

File lib_pypy/numpypy/core/numeric.py

 import _numpypy as multiarray # ARGH
 from numpypy.core.arrayprint import array2string
 
-
+newaxis = None
 
 def asanyarray(a, dtype=None, order=None, maskna=None, ownmaskna=False):
     """
 False_ = bool_(False)
 True_ = bool_(True)
 e = math.e
-pi = math.pi
+pi = math.pi

File pypy/module/micronumpy/interp_iter.py

 # structures to describe slicing
 
 class Chunk(object):
+    axis_step = 1
     def __init__(self, start, stop, step, lgt):
         self.start = start
         self.stop = stop
         return 'Chunk(%d, %d, %d, %d)' % (self.start, self.stop, self.step,
                                           self.lgt)
 
+class NewAxisChunk(Chunk):
+    start = 0
+    stop = 1
+    step = 1
+    lgt = 1
+    axis_step = 0
+
+    def __init__(self):
+        pass
+
 class BaseTransform(object):
     pass
 

File pypy/module/micronumpy/interp_numarray.py

 from pypy.module.micronumpy.appbridge import get_appbridge_cache
 from pypy.module.micronumpy.dot import multidim_dot, match_dot_shapes
 from pypy.module.micronumpy.interp_iter import (ArrayIterator,
-    SkipLastAxisIterator, Chunk, ViewIterator)
+    SkipLastAxisIterator, Chunk, NewAxisChunk, ViewIterator)
 from pypy.module.micronumpy.strides import (calculate_slice_strides,
     shape_agreement, find_shape_and_elems, get_shape_from_iterable,
-    calc_new_strides, to_coords)
+    calc_new_strides, to_coords, enumerate_chunks)
 from pypy.rlib import jit
 from pypy.rlib.rstring import StringBuilder
 from pypy.rpython.lltypesystem import lltype, rffi
         is a list of scalars that match the size of shape
         """
         shape_len = len(self.shape)
+        if space.isinstance_w(w_idx, space.w_tuple):
+            for w_item in space.fixedview(w_idx):
+                if (space.isinstance_w(w_item, space.w_slice) or
+                    space.isinstance_w(w_item, space.w_NoneType)):
+                    return False
+        elif space.isinstance_w(w_idx, space.w_NoneType):
+            return False
         if shape_len == 0:
             raise OperationError(space.w_IndexError, space.wrap(
                 "0-d arrays can't be indexed"))
         if lgt > shape_len:
             raise OperationError(space.w_IndexError,
                                  space.wrap("invalid index"))
-        if lgt < shape_len:
-            return False
-        for w_item in space.fixedview(w_idx):
-            if space.isinstance_w(w_item, space.w_slice):
-                return False
-        return True
+        return lgt == shape_len
 
     @jit.unroll_safe
     def _prepare_slice_args(self, space, w_idx):
         if (space.isinstance_w(w_idx, space.w_int) or
             space.isinstance_w(w_idx, space.w_slice)):
             return [Chunk(*space.decode_index4(w_idx, self.shape[0]))]
-        return [Chunk(*space.decode_index4(w_item, self.shape[i])) for i, w_item in
-                enumerate(space.fixedview(w_idx))]
+        elif space.isinstance_w(w_idx, space.w_NoneType):
+            return [NewAxisChunk()]
+        result = []
+        i = 0
+        for w_item in space.fixedview(w_idx):
+            if space.isinstance_w(w_item, space.w_NoneType):
+                result.append(NewAxisChunk())
+            else:
+                result.append(Chunk(*space.decode_index4(w_item,
+                                                         self.shape[i])))
+                i += 1
+        return result
 
     def count_all_true(self, arr):
         sig = arr.find_sig()
     def create_slice(self, chunks):
         shape = []
         i = -1
-        for i, chunk in enumerate(chunks):
+        for i, chunk in enumerate_chunks(chunks):
             chunk.extend_shape(shape)
         s = i + 1
         assert s >= 0

File pypy/module/micronumpy/strides.py

 from pypy.rlib import jit
 from pypy.interpreter.error import OperationError
 
+def enumerate_chunks(chunks):
+    result = []
+    i = -1
+    for chunk in chunks:
+        i += chunk.axis_step
+        result.append((i, chunk))
+    return result
+
 @jit.look_inside_iff(lambda shape, start, strides, backstrides, chunks:
     jit.isconstant(len(chunks))
 )
     rstart = start
     rshape = []
     i = -1
-    for i, chunk in enumerate(chunks):
+    for i, chunk in enumerate_chunks(chunks):
         if chunk.step != 0:
             rstrides.append(strides[i] * chunk.step)
             rbackstrides.append(strides[i] * (chunk.lgt - 1) * chunk.step)

File pypy/module/micronumpy/test/test_numarray.py

         assert a[1] == 0.
         assert a[3] == 0.
 
+    def test_newaxis(self):
+        from _numpypy import array
+        from numpypy.core.numeric import newaxis
+        a = array(range(5))
+        b = array([range(5)])
+        assert (a[newaxis] == b).all()
+
+    def test_newaxis_slice(self):
+        from _numpypy import array
+        from numpypy.core.numeric import newaxis
+
+        a = array(range(5))
+        b = array(range(1,5))
+        c = array([range(1,5)])
+        d = array([[x] for x in range(1,5)])
+
+        assert (a[1:] == b).all()
+        assert (a[1:,newaxis] == d).all()
+        assert (a[newaxis,1:] == c).all()
+
+    def test_newaxis_assign(self):
+        from _numpypy import array
+        from numpypy.core.numeric import newaxis
+
+        a = array(range(5))
+        a[newaxis,1] = [2]
+        assert a[1] == 2
+
+    def test_newaxis_virtual(self):
+        from _numpypy import array
+        from numpypy.core.numeric import newaxis
+
+        a = array(range(5))
+        b = (a + a)[newaxis]
+        c = array([[0, 2, 4, 6, 8]])
+        assert (b == c).all()
+
+    def test_newaxis_then_slice(self):
+        from _numpypy import array
+        from numpypy.core.numeric import newaxis
+        a = array(range(5))
+        b = a[newaxis]
+        assert (b[0,1:] == a[1:]).all()
+
+    def test_slice_then_newaxis(self):
+        from _numpypy import array
+        from numpypy.core.numeric import newaxis
+        a = array(range(5))
+        b = a[2:]
+        assert (b[newaxis] == [[2, 3, 4]]).all()
+
     def test_scalar(self):
         from _numpypy import array, dtype
         a = array(3)