Commits

Anonymous committed b3836fc

add BroadcastUfunc iter, more tests pass

Comments (0)

Files changed (3)

pypy/module/micronumpy/interp_numarray.py

         if self.forced_result is not None:
             return self.forced_result.create_sig()
         if self.shape != self.values.shape:
-            xxx 
+            #This happens if out arg is used
+            return signature.BroadcastUfunc(self.ufunc, self.name,
+                                            self.calc_dtype,
+                                            self.values.create_sig(),
+                                            self.res.create_sig())
         return signature.Call1(self.ufunc, self.name, self.calc_dtype,
                                self.values.create_sig())
 
         if res is None:
             res = W_NDimArray(size, shape, dtype, order)
         assert isinstance(res, BaseArray)
-        Call2.__init__(self, None, 'assign', shape, dtype, dtype, res, child)
+        concr = res.get_concrete()
+        Call2.__init__(self, None, 'assign', shape, dtype, dtype, concr, child)
 
     def create_sig(self):
         sig = signature.ResultSignature(self.res_dtype, self.left.create_sig(),

pypy/module/micronumpy/signature.py

         return self.child.eval(frame, arr.child)
 
 class Call1(Signature):
-    _immutable_fields_ = ['unfunc', 'name', 'child', 'dtype']
+    _immutable_fields_ = ['unfunc', 'name', 'child', 'res', 'dtype']
 
-    def __init__(self, func, name, dtype, child):
+    def __init__(self, func, name, dtype, child, res=None):
         self.unfunc = func
         self.child = child
         self.name = name
         self.dtype = dtype
+        self.res  = res
 
     def hash(self):
         return compute_hash(self.name) ^ intmask(self.child.hash() << 1)
         v = self.child.eval(frame, arr.values).convert_to(arr.calc_dtype)
         return self.unfunc(arr.calc_dtype, v)
 
+
+class BroadcastUfunc(Call1):
+    def _invent_numbering(self, cache, allnumbers):
+        self.res._invent_numbering(cache, allnumbers)
+        self.child._invent_numbering(new_cache(), allnumbers)
+
+    def debug_repr(self):
+        return 'BroadcastUfunc(%s, %s)' % (self.name, self.child.debug_repr())
+
+    def _create_iter(self, iterlist, arraylist, arr, transforms):
+        from pypy.module.micronumpy.interp_numarray import Call1
+
+        assert isinstance(arr, Call1)
+        vtransforms = transforms + [BroadcastTransform(arr.values.shape)]
+        self.child._create_iter(iterlist, arraylist, arr.values, vtransforms)
+        self.res._create_iter(iterlist, arraylist, arr.res, transforms)
+
+    def eval(self, frame, arr):
+        from pypy.module.micronumpy.interp_numarray import Call1
+        assert isinstance(arr, Call1)
+        v = self.child.eval(frame, arr.values).convert_to(arr.calc_dtype)
+        return self.unfunc(arr.calc_dtype, v)
+
 class Call2(Signature):
     _immutable_fields_ = ['binfunc', 'name', 'calc_dtype', 'left', 'right']
 

pypy/module/micronumpy/test/test_outarg.py

 
     def test_ufunc_out(self):
         from _numpypy import array, negative, zeros, sin
+        from math import sin as msin
         a = array([[1, 2], [3, 4]])
         c = zeros((2,2,2))
         b = negative(a + a, out=c[1])
         assert b.shape == c.shape
         a = array([1, 2])
         b = sin(a, out=c)
-        assert(c == [[-1, -2], [-1, -2]]).all()
+        assert(c == [[msin(1), msin(2)]] * 2).all()
         b = sin(a, out=c+c)
         assert (c == b).all()
 
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.