Commits

Anonymous committed f6e3323

add yet another signature, fix ViewIterator.apply_transformations() optimization

Comments (0)

Files changed (4)

pypy/module/micronumpy/interp_iter.py

 
     def apply_transformations(self, arr, transformations):
         v = BaseIterator.apply_transformations(self, arr, transformations)
-        if len(arr.shape) == 1:
+        if len(arr.shape) == 1 and len(v.res_shape) == 1:
             return OneDimIterator(self.offset, self.strides[0],
                                   self.res_shape[0])
         return v

pypy/module/micronumpy/interp_numarray.py

         ra = ResultArray(self, self.size, self.shape, self.res_dtype,
                                                                 self.res)
         loop.compute(ra)
+        if self.res:
+            broadcast_dims = len(self.res.shape) - len(self.shape)
+            chunks = [Chunk(0,0,0,0)] * broadcast_dims + \
+                     [Chunk(0, i, 1, i) for i in self.shape]
+            return ra.left.create_slice(chunks)
         return ra.left
 
     def force_if_needed(self):
     def __init__(self, child, size, shape, dtype, res=None, order='C'):
         if res is None:
             res = W_NDimArray(size, shape, dtype, order)
-        assert isinstance(res, BaseArray)
-        concr = res.get_concrete()
-        Call2.__init__(self, None, 'assign', shape, dtype, dtype, concr, child)
+        else:
+            assert isinstance(res, BaseArray)
+            #Make sure it is not a virtual array i.e. out=a+a
+            res = res.get_concrete()
+        Call2.__init__(self, None, 'assign', shape, dtype, dtype, res, child)
 
     def create_sig(self):
-        sig = signature.ResultSignature(self.res_dtype, self.left.create_sig(),
-                                         self.right.create_sig())
+        if self.left.shape != self.right.shape:
+            sig = signature.BroadcastResultSignature(self.res_dtype,
+                        self.left.create_sig(), self.right.create_sig())
+        else:
+            sig = signature.ResultSignature(self.res_dtype, 
+                        self.left.create_sig(), self.right.create_sig())
         return sig
 
 def done_if_true(dtype, val):

pypy/module/micronumpy/signature.py

 
         assert isinstance(arr, ResultArray)
         offset = frame.get_final_iter().offset
-        arr.left.setitem(offset, self.right.eval(frame, arr.right))
+        val = self.right.eval(frame, arr.right)
+        arr.left.setitem(offset, val)
+
+class BroadcastResultSignature(ResultSignature):
+    def _create_iter(self, iterlist, arraylist, arr, transforms):
+        from pypy.module.micronumpy.interp_numarray import ResultArray
+
+        assert isinstance(arr, ResultArray)
+        rtransforms = transforms + [BroadcastTransform(arr.left.shape)]
+        self.left._create_iter(iterlist, arraylist, arr.left, transforms)
+        self.right._create_iter(iterlist, arraylist, arr.right, rtransforms)
 
 class BroadcastLeft(Call2):
     def _invent_numbering(self, cache, allnumbers):

pypy/module/micronumpy/test/test_outarg.py

         assert (c == out).all()
         assert c.shape == a.shape
         assert c.dtype is a.dtype
+        c[0,0] = 100
+        assert out[0, 0] == 100
         raises(ValueError, 'c = add(a, a, out=out[1])')
+        c = add(a[0], a[1], out=out[1])
+        assert (c == out[1]).all()
+        assert (c == [4, 6]).all()
+        c = add(a[0], a[1], out=out)
+        assert (c == out[1]).all()
+        assert (c == out[0]).all()
+        
 
     def test_ufunc_cast(self):
         from _numpypy import array, negative
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.