Commits

Konstantin Lopuhin committed 1098dd9

add 3d secialized iterator too

Comments (0)

Files changed (1)

pypy/module/micronumpy/iter.py

     if len(shape) == 2:
         return TwoDimViewIterator(
                 array, start, strides, backstrides, shape)
+    elif len(shape) == 3:
+        return ThreeDimViewIterator(
+                array, start, strides, backstrides, shape)
     else:
         return BaseMultiDimViewIterator(
                 array, start, strides, backstrides, shape)
             self.indexes0 = self.indexes0 +  this_i_step
         self._done = True
 
-    @property
-    def indexes(self):
-        return [self.indexes0, self.indexes1]
+   #@property
+   #def indexes(self):
+   #    return [self.indexes0, self.indexes1]
+
+class ThreeDimViewIterator(BaseMultiDimViewIterator):
+    def __init__(self, array, start, strides, backstrides, shape):
+        self.array = array
+        self.offset = start
+        self._done = product(shape) == 0
+        self.size = array.size
+        self.shape0, self.shape1, self.shape2 = self.shape = shape
+        self.backstrides0, self.backstrides1, self.backstrides2 = backstrides
+        self.strides0, self.strides1, self.strides2 = strides
+        self.indexes0 = self.indexes1 = self.indexes2 = 0
+
+    def next(self):
+        if self.indexes2 < self.shape2 - 1:
+            self.indexes2 += 1
+            self.offset += self.strides2
+            return
+        else:
+            self.indexes2 = 0
+            self.offset -= self.backstrides2
+        if self.indexes1 < self.shape1 - 1:
+            self.indexes1 += 1
+            self.offset += self.strides1
+            return
+        else:
+            self.indexes1 = 0
+            self.offset -= self.backstrides1
+        if self.indexes0 < self.shape0 - 1:
+            self.indexes0 += 1
+            self.offset += self.strides0
+            return
+        else:
+            self.indexes0 = 0
+            self.offset -= self.backstrides0
+        self._done = True
+
+    def next_skip_x(self, step):
+        if self.indexes2 < self.shape2 - step:
+            self.indexes2 += step
+            self.offset += self.strides2 * step
+            return
+        else:
+            remaining_step = (self.indexes2 + step) // self.shape2
+            this_i_step = step - remaining_step * self.shape2
+            self.offset += self.strides2 * this_i_step
+            self.indexes2 = self.indexes2 +  this_i_step
+            step = remaining_step
+        if self.indexes1 < self.shape1 - step:
+            self.indexes1 += step
+            self.offset += self.strides1 * step
+            return
+        else:
+            remaining_step = (self.indexes1 + step) // self.shape1
+            this_i_step = step - remaining_step * self.shape1
+            self.offset += self.strides1 * this_i_step
+            self.indexes1 = self.indexes1 +  this_i_step
+            step = remaining_step
+        if self.indexes0 < self.shape0 - step:
+            self.indexes0 += step
+            self.offset += self.strides0 * step
+            return
+        else:
+            remaining_step = (self.indexes0 + step) // self.shape0
+            this_i_step = step - remaining_step * self.shape0
+            self.offset += self.strides0 * this_i_step
+            self.indexes0 = self.indexes0 +  this_i_step
+        self._done = True
 
 class AxisIterator(base.BaseArrayIterator):
     def __init__(self, array, shape, dim, cumulative):