Commits

Antonio Cuni  committed c953de9

completely change the strategy for inlining jit_merge_points in the
caller. The old one did not work in complex cases involving raising graphs. In
particular, this case did not work:

    def bar():
        "something which cannot be inlined and raises"

    @inline_in_portal
    def foo():
        driver.jit_merge_point()
        return bar()

    def fn():
        try:
            foo()
        except StopIteration:
            pass

that's because the backendopt inliner is not able to inline calls to
raising graphs inside a try/except block.

To work around the issue, we put the actual driver.jit_merge_point in a
separate function, which needs to be called as soon as we enter foo(). Then,
we move *only* this call from foo() to fn(), and finally inline the
jit_merge_point in fn().

Next step is to provide a nice decorator to do everything automatically.

  • Participants
  • Parent commits 359bb4a
  • Branches autoreds

Comments (0)

Files changed (4)

File pypy/jit/metainterp/test/test_warmspot.py

         assert res == expected
         self.check_resops(int_sub=2, int_mul=0, int_add=2)
 
+    def test_inline_jit_merge_point(self):
+        # test that the machinery to inline jit_merge_points in callers
+        # works. The final user does not need to mess manually with the
+        # _inline_jit_merge_point_ attribute and similar, it is all nicely
+        # handled by @JitDriver.inline()
+        myjitdriver = JitDriver(greens = ['a'], reds = 'auto')
+
+        def jit_merge_point(a, b):
+            myjitdriver.jit_merge_point(a=a)
+
+        def add(a, b):
+            jit_merge_point(a, b)
+            return a+b
+        add._inline_jit_merge_point_ = jit_merge_point
+        myjitdriver.inline_jit_merge_point = True
+
+        def calc(n):
+            res = 0
+            while res < 1000:
+                res = add(n, res)
+            return res
+
+        def f():
+            return calc(1) + calc(3)
+
+        res = self.meta_interp(f, [])
+        assert res == 1000 + 1002
+        self.check_resops(int_add=4)
+
+
     def test_inline_in_portal(self):
+        py.test.skip('in-progress')
         myjitdriver = JitDriver(greens = [], reds = 'auto')
         class MyRange(object):
             def __init__(self, n):
         self.check_resops(int_eq=4, int_add=8)
         self.check_trace_count(2)
 
-    def test_inline_in_portal_exception(self):
-        myjitdriver = JitDriver(greens = [], reds = 'auto')
-        def inc(n):
-            if n == 1000:
-                raise OverflowError
-            return n+1
-
-        @myjitdriver.inline_in_portal
-        def jitted_inc(n):
-            myjitdriver.jit_merge_point()
-            return inc(n)
-
-        def f():
-            res = 0
-            while True:
-                try:
-                    res = jitted_inc(res)
-                except OverflowError:
-                    break
-            return res
-        res = self.meta_interp(f, [])
-        assert res == 1000
-        self.check_resops(int_add=2)
-
 
 class TestLLWarmspot(WarmspotTests, LLJitMixin):
     CPUClass = runner.LLtypeCPU

File pypy/jit/metainterp/warmspot.py

 
     def inline_inlineable_portals(self):
         """
-        Find all the graphs which have been decorated with
-        @jitdriver.inline_in_portal and inline them in the callers, making
-        them JIT portals. Then, create a fresh copy of the jitdriver for each
-        of those new portals, because they cannot share the same one.  See
-        test_ajit::test_inline_in_portal.
+        Find all the graphs which have been decorated with @jitdriver.inline
+        and inline them in the callers, making them JIT portals. Then, create
+        a fresh copy of the jitdriver for each of those new portals, because
+        they cannot share the same one.  See
+        test_ajit::test_inline_jit_merge_point
         """
-        from pypy.translator.backendopt import inline
+        from pypy.translator.backendopt.inline import (
+            get_funcobj, inlinable_static_callers, auto_inlining)
 
         # find all the graphs which call an @inline_in_portal function
-        callgraph = inline.inlinable_static_callers(self.translator.graphs)
+        callgraph = inlinable_static_callers(self.translator.graphs, store_calls=True)
         new_callgraph = []
         new_portals = set()
-        for caller, callee in callgraph:
+        for caller, block, op_call, callee in callgraph:
             func = getattr(callee, 'func', None)
-            _inline_in_portal_ = getattr(func, '_inline_in_portal_', False)
-            if _inline_in_portal_:
-                new_callgraph.append((caller, callee))
+            _inline_jit_merge_point_ = getattr(func, '_inline_jit_merge_point_', None)
+            if _inline_jit_merge_point_:
+                # we are calling a function which has been decorated with
+                # @jitdriver.inline: the very first op of the callee graph is
+                # a call to the function which contains the actual
+                # jit_merge_point: fish it!
+                jmp_block, op_jmp_call = next(callee.iterblockops())
+                msg = ("The first operation of an _inline_jit_merge_point_ graph must be "
+                       "a direct_call to the function passed to @jitdriver.inline()")
+                assert op_jmp_call.opname == 'direct_call', msg
+                jmp_funcobj = get_funcobj(op_jmp_call.args[0].value)
+                assert jmp_funcobj._callable is _inline_jit_merge_point_, msg
+                #
+                # now we move the op_jmp_call from callee to caller, just
+                # before op_call. We assume that the args passed to
+                # op_jmp_call are the very same which are received by callee
+                # (i.e., the one passed to op_call)
+                assert len(op_call.args) == len(op_jmp_call.args)
+                jmp_block.operations.remove(op_jmp_call)
+                op_jmp_call.args[1:] = op_call.args[1:]
+                idx = block.operations.index(op_call)
+                block.operations.insert(idx, op_jmp_call)
+                #
+                # finally, we signal that we want to inline op_jmp_call into
+                # caller, so that finally the actual call to
+                # driver.jit_merge_point will be seen there
+                new_callgraph.append((caller, jmp_funcobj.graph))
                 new_portals.add(caller)
 
         # inline them!
         inline_threshold = self.translator.config.translation.backendopt.inline_threshold
-        inline.auto_inlining(self.translator, inline_threshold, callgraph)
+        auto_inlining(self.translator, inline_threshold, new_callgraph)
 
         # make a fresh copy of the JitDriver in all newly created
         # jit_merge_points
             op = block.operations[pos]
             v_driver = op.args[1]
             driver = v_driver.value
-            if not driver.inlined_in_portal:
+            if not driver.inline_jit_merge_point:
                 continue
             new_driver = driver.clone()
             c_new_driver = Constant(new_driver, v_driver.concretetype)
                         alive_v.add(op1.result)
                 greens_v = op.args[2:]
                 reds_v = alive_v - set(greens_v)
+                reds_v = [v for v in reds_v if v.concretetype is not lltype.Void]
                 reds_v = support.sort_vars(reds_v)
                 op.args.extend(reds_v)
                 if jitdriver.numreds is None:

File pypy/rlib/jit.py

         return func
 
     def clone(self):
-        assert self.inlined_in_portal, 'JitDriver.clone works only after @inline_in_portal'
+        assert self.inline_jit_merge_point, 'JitDriver.clone works only after @inline'
         newdriver = object.__new__(self.__class__)
         newdriver.__dict__ = self.__dict__.copy()
         return newdriver

File pypy/translator/backendopt/inline.py

     return (0.9999 * measure_median_execution_cost(graph) +
             count), True
 
-def inlinable_static_callers(graphs):
+def inlinable_static_callers(graphs, store_calls=False):
     ok_to_call = set(graphs)
     result = []
+    def add(parentgraph, block, op, graph):
+        if store_calls:
+            result.append((parentgraph, block, op, graph))
+        else:
+            result.append((parentgraph, graph))
+    #
     for parentgraph in graphs:
         for block in parentgraph.iterblocks():
             for op in block.operations:
                         if getattr(getattr(funcobj, '_callable', None),
                                    '_dont_inline_', False):
                             continue
-                        result.append((parentgraph, graph))
+                        add(parentgraph, block, op, graph)
                 if op.opname == "oosend":
                     meth = get_meth_from_oosend(op)
                     graph = getattr(meth, 'graph', None)
                     if graph is not None and graph in ok_to_call:
-                        result.append((parentgraph, graph))
+                        add(parentgraph, block, op, graph)
     return result
     
 def instrument_inline_candidates(graphs, threshold):