Commits

Antonio Cuni committed 9db0af7

add support for calling the @jitdriver.inline()d function multiple times

Comments (0)

Files changed (2)

pypy/jit/metainterp/test/test_warmspot.py

         assert res == 1000 + 1002
         self.check_resops(int_add=4)
 
-
     def test_jitdriver_inline(self):
         myjitdriver = JitDriver(greens = [], reds = 'auto')
         class MyRange(object):
         self.check_resops(int_eq=2, int_add=4)
         self.check_trace_count(1)
 
+    def test_jitdriver_inline_twice(self):
+        myjitdriver = JitDriver(greens = [], reds = 'auto')
+
+        def jit_merge_point(a, b):
+            myjitdriver.jit_merge_point()
+
+        @myjitdriver.inline(jit_merge_point)
+        def add(a, b):
+            return a+b
+
+        def one(n):
+            res = 0
+            while res < 1000:
+                res = add(n, res)
+            return res
+
+        def two(n):
+            res = 0
+            while res < 2000:
+                res = add(n, res)
+            return res
+
+        def f(n):
+            return one(n) + two(n)
+
+        res = self.meta_interp(f, [1])
+        assert res == 3000
+        self.check_resops(int_add=4)
+        self.check_trace_count(2)
+
 
 class TestLLWarmspot(WarmspotTests, LLJitMixin):
     CPUClass = runner.LLtypeCPU

pypy/jit/metainterp/warmspot.py

         from pypy.translator.backendopt.inline import (
             get_funcobj, inlinable_static_callers, auto_inlining)
 
+        jmp_calls = {}
+        def get_jmp_call(graph, _inline_jit_merge_point_):
+            # there might be multiple calls to the @inlined function: the
+            # first time we see it, we remove the call to the jit_merge_point
+            # and we remember the corresponding op. Then, we create a new call
+            # to it every time we need a new one (i.e., for each callsite
+            # which becomes a new portal)
+            try:
+                op, jmp_graph = jmp_calls[graph]
+            except KeyError:
+                op, jmp_graph = fish_jmp_call(graph, _inline_jit_merge_point_)
+                jmp_calls[graph] = op, jmp_graph
+            #
+            # clone the op
+            newargs = op.args[:]
+            newresult = Variable()
+            newresult.concretetype = op.result.concretetype
+            op = SpaceOperation(op.opname, newargs, newresult)
+            return op, jmp_graph
+
+        def fish_jmp_call(graph, _inline_jit_merge_point_):
+            # graph is function which has been decorated with
+            # @jitdriver.inline, so its very first op is a call to the
+            # function which contains the actual jit_merge_point: fish it!
+            jmp_block, op_jmp_call = next(callee.iterblockops())
+            msg = ("The first operation of an _inline_jit_merge_point_ graph must be "
+                   "a direct_call to the function passed to @jitdriver.inline()")
+            assert op_jmp_call.opname == 'direct_call', msg
+            jmp_funcobj = get_funcobj(op_jmp_call.args[0].value)
+            assert jmp_funcobj._callable is _inline_jit_merge_point_, msg
+            jmp_block.operations.remove(op_jmp_call)
+            return op_jmp_call, jmp_funcobj.graph
+
         # find all the graphs which call an @inline_in_portal function
         callgraph = inlinable_static_callers(self.translator.graphs, store_calls=True)
         new_callgraph = []
             func = getattr(callee, 'func', None)
             _inline_jit_merge_point_ = getattr(func, '_inline_jit_merge_point_', None)
             if _inline_jit_merge_point_:
-                # we are calling a function which has been decorated with
-                # @jitdriver.inline: the very first op of the callee graph is
-                # a call to the function which contains the actual
-                # jit_merge_point: fish it!
-                jmp_block, op_jmp_call = next(callee.iterblockops())
-                msg = ("The first operation of an _inline_jit_merge_point_ graph must be "
-                       "a direct_call to the function passed to @jitdriver.inline()")
-                assert op_jmp_call.opname == 'direct_call', msg
-                jmp_funcobj = get_funcobj(op_jmp_call.args[0].value)
-                assert jmp_funcobj._callable is _inline_jit_merge_point_, msg
+                op_jmp_call, jmp_graph = get_jmp_call(callee, _inline_jit_merge_point_)
                 #
                 # now we move the op_jmp_call from callee to caller, just
                 # before op_call. We assume that the args passed to
                 # op_jmp_call are the very same which are received by callee
                 # (i.e., the one passed to op_call)
                 assert len(op_call.args) == len(op_jmp_call.args)
-                jmp_block.operations.remove(op_jmp_call)
                 op_jmp_call.args[1:] = op_call.args[1:]
                 idx = block.operations.index(op_call)
                 block.operations.insert(idx, op_jmp_call)
                 # finally, we signal that we want to inline op_jmp_call into
                 # caller, so that finally the actuall call to
                 # driver.jit_merge_point will be seen there
-                new_callgraph.append((caller, jmp_funcobj.graph))
+                new_callgraph.append((caller, jmp_graph))
                 new_portals.add(caller)
 
         # inline them!