mattip committed 3005bc6

wip - finding shape for iterators including external_loop flag

Comments (0)

Files changed (1)


 from pypy.interpreter.gateway import interp2app, unwrap_spec, WrappedDefault
 from pypy.interpreter.error import OperationError
 from pypy.module.micronumpy.base import W_NDimArray, convert_to_array
-from pypy.module.micronumpy.strides import calculate_broadcast_strides
+from pypy.module.micronumpy.strides import (calculate_broadcast_strides,
+                                             shape_agreement_multiple)
 from pypy.module.micronumpy.iter import MultiDimViewIterator
 from pypy.module.micronumpy import support
 from pypy.module.micronumpy.arrayimpl.concrete import SliceArray
                 'Iterator flag EXTERNAL_LOOP cannot be used if an index or '
                 'multi-index is being tracked'))
-def get_iter(space, order, imp, backward):
+def get_iter(space, order, imp, shape):
     if order == 'K' or (order == 'C' and imp.order == 'C'):
         backward = False
     elif order =='F' and imp.order == 'C':
         # flip the strides. Is this always true for multidimension?
         strides = [s for s in imp.strides[::-1]]
         backstrides = [s for s in imp.backstrides[::-1]]
-        shape = [s for s in imp.shape[::-1]]
+        shape = [s for s in shape[::-1]]
         strides = imp.strides
         backstrides = imp.backstrides
-        shape = imp.shape
-    shape1d = [support.product(imp.shape),]
-    r = calculate_broadcast_strides(strides, backstrides, shape,
-                                    shape1d, backward)
+    r = calculate_broadcast_strides(strides, backstrides, imp.shape,
+                                    shape, backward)
     return MultiDimViewIterator(imp, imp.dtype, imp.start, r[0], r[1], shape)
         self.op_flags = parse_op_arg(space, 'op_flags', w_op_flags,
                                      len(self.seq), parse_op_flag)
+        self.shape = iter_shape = shape_agreement_multiple(space, self.seq)
+        if self.external_loop:
+            xxx find longest contiguous shape
+            iter_shape = iter_shape[1:]
         for i in range(len(self.seq)):
-            # XXX the shape of the iter depends on all the seq.shapes together
             self.iters.append(get_iter(space, self.order,
-                            self.seq[i].implementation, self.op_flags[i]))
+                            self.seq[i].implementation, iter_shape))
     def descr_iter(self, space):
         return space.wrap(self)
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.