Commits

Brian Kearns committed f2b222d

fix nditer getitem return types

  • Participants
  • Parent commits 9f3d775

Comments (0)

Files changed (2)

pypy/module/micronumpy/nditer.py

 from pypy.interpreter.typedef import TypeDef, GetSetProperty
 from pypy.interpreter.gateway import interp2app, unwrap_spec, WrappedDefault
 from pypy.interpreter.error import OperationError, oefmt
-from pypy.module.micronumpy import ufuncs, support
+from pypy.module.micronumpy import ufuncs, support, concrete
 from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
-from pypy.module.micronumpy.concrete import SliceArray
 from pypy.module.micronumpy.descriptor import decode_w_dtype
 from pypy.module.micronumpy.iterators import ArrayIter, SliceIterator
 from pypy.module.micronumpy.strides import (calculate_broadcast_strides,
 class IteratorMixin(object):
     _mixin_ = True
 
-    def __init__(self, it, op_flags):
+    def __init__(self, nditer, it, op_flags):
+        self.nditer = nditer
         self.it = it
         self.st = it.reset()
         self.op_flags = op_flags
         self.st = self.it.next(self.st)
 
     def getitem(self, space, array):
-        return self.op_flags.get_it_item[self.index](space, array, self.it, self.st)
+        return self.op_flags.get_it_item[self.index](space, self.nditer, self.it, self.st)
 
     def setitem(self, space, array, val):
         xxx
         self.get_it_item = (get_readonly_item, get_readonly_slice)
 
 
-def get_readonly_item(space, array, it, st):
-    return space.wrap(it.getitem(st))
+def get_readonly_item(space, nditer, it, st):
+    res = concrete.ConcreteNonWritableArrayWithBase(
+        [], it.array.dtype, it.array.order, [], [], it.array.storage, nditer)
+    res.start = st.offset
+    return W_NDimArray(res)
 
 
-def get_readwrite_item(space, array, it, st):
-    #create a single-value view (since scalars are not views)
-    res = SliceArray(it.array.start + st.offset, [0], [0], [1], it.array, array)
-    #it.dtype.setitem(res, 0, it.getitem())
+def get_readwrite_item(space, nditer, it, st):
+    res = concrete.ConcreteArrayWithBase(
+        [], it.array.dtype, it.array.order, [], [], it.array.storage, nditer)
+    res.start = st.offset
     return W_NDimArray(res)
 
 
         if self.external_loop:
             for i in range(len(self.seq)):
                 self.iters.append(ExternalLoopIterator(
+                    self,
                     get_external_loop_iter(
                         space, self.order, self.seq[i], iter_shape),
                     self.op_flags[i]))
         else:
             for i in range(len(self.seq)):
                 self.iters.append(BoxIterator(
+                    self,
                     get_iter(
                         space, self.order, self.seq[i], iter_shape, self.dtypes[i]),
                     self.op_flags[i]))

pypy/module/micronumpy/test/test_nditer.py

 
 class AppTestNDIter(BaseNumpyAppTest):
     def test_basic(self):
-        from numpy import arange, nditer
+        from numpy import arange, nditer, ndarray
         a = arange(6).reshape(2,3)
+        i = nditer(a)
         r = []
-        for x in nditer(a):
+        for x in i:
+            assert type(x) is ndarray
+            assert x.base is i
+            assert x.shape == ()
+            assert x.strides == ()
+            exc = raises(ValueError, "x[()] = 42")
+            assert str(exc.value) == 'assignment destination is read-only'
             r.append(x)
         assert r == [0, 1, 2, 3, 4, 5]
         r = []
-
         for x in nditer(a.T):
             r.append(x)
         assert r == [0, 1, 2, 3, 4, 5]
         assert r == [0, 3, 1, 4, 2, 5]
 
     def test_readwrite(self):
-        from numpy import arange, nditer
+        from numpy import arange, nditer, ndarray
         a = arange(6).reshape(2,3)
-        for x in nditer(a, op_flags=['readwrite']):
+        i = nditer(a, op_flags=['readwrite'])
+        for x in i:
+            assert type(x) is ndarray
+            assert x.base is i
+            assert x.shape == ()
+            assert x.strides == ()
             x[...] = 2 * x
         assert (a == [[0, 2, 4], [6, 8, 10]]).all()