Commits

Antonio Cuni committed 29f51cb Merge

merge again the autoreds branch, which now uses an approach which seems to work for the upcoming space.iteriterable

Comments (0)

Files changed (8)

pypy/jit/metainterp/test/test_warmspot.py

         assert res == expected
         self.check_resops(int_sub=2, int_mul=0, int_add=2)
 
-    def test_inline_in_portal(self):
+    def test_inline_jit_merge_point(self):
+        # test that the machinery to inline jit_merge_points in callers
+        # works. The final user does not need to mess manually with the
+        # _inline_jit_merge_point_ attribute and similar, it is all nicely
+        # handled by @JitDriver.inline() (see next tests)
+        myjitdriver = JitDriver(greens = ['a'], reds = 'auto')
+
+        def jit_merge_point(a, b):
+            myjitdriver.jit_merge_point(a=a)
+
+        def add(a, b):
+            jit_merge_point(a, b)
+            return a+b
+        add._inline_jit_merge_point_ = jit_merge_point
+        myjitdriver.inline_jit_merge_point = True
+
+        def calc(n):
+            res = 0
+            while res < 1000:
+                res = add(n, res)
+            return res
+
+        def f():
+            return calc(1) + calc(3)
+
+        res = self.meta_interp(f, [])
+        assert res == 1000 + 1002
+        self.check_resops(int_add=4)
+
+    def test_jitdriver_inline(self):
         myjitdriver = JitDriver(greens = [], reds = 'auto')
         class MyRange(object):
             def __init__(self, n):
             def __iter__(self):
                 return self
 
-            @myjitdriver.inline_in_portal
+            def jit_merge_point(self):
+                myjitdriver.jit_merge_point()
+
+            @myjitdriver.inline(jit_merge_point)
             def next(self):
-                myjitdriver.jit_merge_point()
                 if self.cur == self.n:
                     raise StopIteration
                 self.cur += 1
                 return self.cur
 
-        def one():
+        def f(n):
             res = 0
-            for i in MyRange(10):
+            for i in MyRange(n):
                 res += i
             return res
 
-        def two():
+        expected = f(21)
+        res = self.meta_interp(f, [21])
+        assert res == expected
+        self.check_resops(int_eq=2, int_add=4)
+        self.check_trace_count(1)
+
+    def test_jitdriver_inline_twice(self):
+        myjitdriver = JitDriver(greens = [], reds = 'auto')
+
+        def jit_merge_point(a, b):
+            myjitdriver.jit_merge_point()
+
+        @myjitdriver.inline(jit_merge_point)
+        def add(a, b):
+            return a+b
+
+        def one(n):
             res = 0
-            for i in MyRange(13):
-                res += i * 2
+            while res < 1000:
+                res = add(n, res)
             return res
 
-        def f(n, m):
-            res = one() * 100
-            res += two()
+        def two(n):
+            res = 0
+            while res < 2000:
+                res = add(n, res)
             return res
-        expected = f(21, 5)
-        res = self.meta_interp(f, [21, 5])
+
+        def f(n):
+            return one(n) + two(n)
+
+        res = self.meta_interp(f, [1])
+        assert res == 3000
+        self.check_resops(int_add=4)
+        self.check_trace_count(2)
+
+    def test_jitdriver_inline_exception(self):
+        # this simulates what happens in a real case scenario: inside the next
+        # we have a call which we cannot inline (e.g. space.next in the case
+        # of W_InterpIterable), but we need to put it in a try/except block.
+        # With the first "inline_in_portal" approach, this case crashed
+        myjitdriver = JitDriver(greens = [], reds = 'auto')
+        
+        def inc(x, n):
+            if x == n:
+                raise OverflowError
+            return x+1
+        inc._dont_inline_ = True
+        
+        class MyRange(object):
+            def __init__(self, n):
+                self.cur = 0
+                self.n = n
+
+            def __iter__(self):
+                return self
+
+            def jit_merge_point(self):
+                myjitdriver.jit_merge_point()
+
+            @myjitdriver.inline(jit_merge_point)
+            def next(self):
+                try:
+                    self.cur = inc(self.cur, self.n)
+                except OverflowError:
+                    raise StopIteration
+                return self.cur
+
+        def f(n):
+            res = 0
+            for i in MyRange(n):
+                res += i
+            return res
+
+        expected = f(21)
+        res = self.meta_interp(f, [21])
         assert res == expected
-        self.check_resops(int_eq=4, int_add=8)
-        self.check_trace_count(2)
+        self.check_resops(int_eq=2, int_add=4)
+        self.check_trace_count(1)
+
 
 class TestLLWarmspot(WarmspotTests, LLJitMixin):
     CPUClass = runner.LLtypeCPU

pypy/jit/metainterp/warmspot.py

 
     def inline_inlineable_portals(self):
         """
-        Find all the graphs which have been decorated with
-        @jitdriver.inline_in_portal and inline them in the callers, making
-        them JIT portals. Then, create a fresh copy of the jitdriver for each
-        of those new portals, because they cannot share the same one.  See
-        test_ajit::test_inline_in_portal.
+        Find all the graphs which have been decorated with @jitdriver.inline
+        and inline them in the callers, making them JIT portals. Then, create
+        a fresh copy of the jitdriver for each of those new portals, because
+        they cannot share the same one.  See
+        test_ajit::test_inline_jit_merge_point
         """
-        from pypy.translator.backendopt import inline
-        lltype_to_classdef = self.translator.rtyper.lltype_to_classdef_mapping()
-        raise_analyzer = inline.RaiseAnalyzer(self.translator)
-        callgraph = inline.inlinable_static_callers(self.translator.graphs)
+        from pypy.translator.backendopt.inline import (
+            get_funcobj, inlinable_static_callers, auto_inlining)
+
+        jmp_calls = {}
+        def get_jmp_call(graph, _inline_jit_merge_point_):
+            # there might be multiple calls to the @inlined function: the
+            # first time we see it, we remove the call to the jit_merge_point
+            # and we remember the corresponding op. Then, we create a new call
+            # to it every time we need a new one (i.e., for each callsite
+            # which becomes a new portal)
+            try:
+                op, jmp_graph = jmp_calls[graph]
+            except KeyError:
+                op, jmp_graph = fish_jmp_call(graph, _inline_jit_merge_point_)
+                jmp_calls[graph] = op, jmp_graph
+            #
+            # clone the op
+            newargs = op.args[:]
+            newresult = Variable()
+            newresult.concretetype = op.result.concretetype
+            op = SpaceOperation(op.opname, newargs, newresult)
+            return op, jmp_graph
+
+        def fish_jmp_call(graph, _inline_jit_merge_point_):
+            # graph is function which has been decorated with
+            # @jitdriver.inline, so its very first op is a call to the
+            # function which contains the actual jit_merge_point: fish it!
+            jmp_block, op_jmp_call = next(callee.iterblockops())
+            msg = ("The first operation of an _inline_jit_merge_point_ graph must be "
+                   "a direct_call to the function passed to @jitdriver.inline()")
+            assert op_jmp_call.opname == 'direct_call', msg
+            jmp_funcobj = get_funcobj(op_jmp_call.args[0].value)
+            assert jmp_funcobj._callable is _inline_jit_merge_point_, msg
+            jmp_block.operations.remove(op_jmp_call)
+            return op_jmp_call, jmp_funcobj.graph
+
+        # find all the graphs which call an @inline_in_portal function
+        callgraph = inlinable_static_callers(self.translator.graphs, store_calls=True)
+        new_callgraph = []
         new_portals = set()
-        for caller, callee in callgraph:
+        for caller, block, op_call, callee in callgraph:
             func = getattr(callee, 'func', None)
-            _inline_in_portal_ = getattr(func, '_inline_in_portal_', False)
-            if _inline_in_portal_:
-                count = inline.inline_function(self.translator, callee, caller,
-                                               lltype_to_classdef, raise_analyzer)
-                assert count > 0, ('The function has been decorated with '
-                                   '@inline_in_portal, but it is not possible '
-                                   'to inline it')
+            _inline_jit_merge_point_ = getattr(func, '_inline_jit_merge_point_', None)
+            if _inline_jit_merge_point_:
+                _inline_jit_merge_point_._always_inline_ = True
+                op_jmp_call, jmp_graph = get_jmp_call(callee, _inline_jit_merge_point_)
+                #
+                # now we move the op_jmp_call from callee to caller, just
+                # before op_call. We assume that the args passed to
+                # op_jmp_call are the very same which are received by callee
+                # (i.e., the one passed to op_call)
+                assert len(op_call.args) == len(op_jmp_call.args)
+                op_jmp_call.args[1:] = op_call.args[1:]
+                idx = block.operations.index(op_call)
+                block.operations.insert(idx, op_jmp_call)
+                #
+                # finally, we signal that we want to inline op_jmp_call into
+                # caller, so that finally the actuall call to
+                # driver.jit_merge_point will be seen there
+                new_callgraph.append((caller, jmp_graph))
                 new_portals.add(caller)
+
+        # inline them!
+        inline_threshold = 0.1 # we rely on the _always_inline_ set above
+        auto_inlining(self.translator, inline_threshold, new_callgraph)
+
+        # make a fresh copy of the JitDriver in all newly created
+        # jit_merge_points
         self.clone_inlined_jit_merge_points(new_portals)
 
     def clone_inlined_jit_merge_points(self, graphs):
         for graph, block, pos in find_jit_merge_points(graphs):
             op = block.operations[pos]
             v_driver = op.args[1]
-            new_driver = v_driver.value.clone()
+            driver = v_driver.value
+            if not driver.inline_jit_merge_point:
+                continue
+            new_driver = driver.clone()
             c_new_driver = Constant(new_driver, v_driver.concretetype)
             op.args[1] = c_new_driver
 
                         alive_v.add(op1.result)
                 greens_v = op.args[2:]
                 reds_v = alive_v - set(greens_v)
+                reds_v = [v for v in reds_v if v.concretetype is not lltype.Void]
                 reds_v = support.sort_vars(reds_v)
                 op.args.extend(reds_v)
                 if jitdriver.numreds is None:
 from pypy.rlib.objectmodel import CDefinedIntSymbolic, keepalive_until_here, specialize
 from pypy.rlib.unroll import unrolling_iterable
 from pypy.rpython.extregistry import ExtRegistryEntry
+from pypy.tool.sourcetools import rpython_wrapper
 
 DEBUG_ELIDABLE_FUNCTIONS = False
 
     active = True          # if set to False, this JitDriver is ignored
     virtualizables = []
     name = 'jitdriver'
-    inlined_in_portal = False
+    inline_jit_merge_point = False
 
     def __init__(self, greens=None, reds=None, virtualizables=None,
                  get_jitcell_at=None, set_jitcell_at=None,
         # special-cased by ExtRegistryEntry
         pass
 
-    def inline_in_portal(self, func):
-        assert self.autoreds, "inline_in_portal works only with reds='auto'"
-        func._inline_in_portal_ = True
-        self.inlined_in_portal = True
-        return func
+    def inline(self, call_jit_merge_point):
+        assert self.autoreds, "@inline works only with reds='auto'"
+        self.inline_jit_merge_point = True
+        def decorate(func):
+            template = """
+                def {name}({arglist}):
+                    {call_jit_merge_point}({arglist})
+                    return {original}({arglist})
+            """
+            templateargs = {'call_jit_merge_point': call_jit_merge_point.__name__}
+            globaldict = {call_jit_merge_point.__name__: call_jit_merge_point}
+            result = rpython_wrapper(func, template, templateargs, **globaldict)
+            result._inline_jit_merge_point_ = call_jit_merge_point
+            return result
+
+        return decorate
+        
 
     def clone(self):
-        assert self.inlined_in_portal, 'JitDriver.clone works only after @inline_in_portal'
+        assert self.inline_jit_merge_point, 'JitDriver.clone works only after @inline'
         newdriver = object.__new__(self.__class__)
         newdriver.__dict__ = self.__dict__.copy()
         return newdriver

pypy/rlib/objectmodel.py

 import types
 import math
 import inspect
+from pypy.tool.sourcetools import rpython_wrapper
 
 # specialize is a decorator factory for attaching _annspecialcase_
 # attributes to functions: for example
                         f.func_name, srcargs[i], expected_type)
                     raise TypeError, msg
         #
-        # we cannot simply wrap the function using *args, **kwds, because it's
-        # not RPython. Instead, we generate a function with exactly the same
-        # argument list
+        template = """
+            def {name}({arglist}):
+                if not we_are_translated():
+                    typecheck({arglist})    # pypy.rlib.objectmodel
+                return {original}({arglist})
+        """
+        result = rpython_wrapper(f, template,
+                                 typecheck=typecheck,
+                                 we_are_translated=we_are_translated)
+        #
         srcargs, srcvarargs, srckeywords, defaults = inspect.getargspec(f)
         if kwds:
             types = tuple([kwds.get(arg) for arg in srcargs])
         assert len(srcargs) == len(types), (
             'not enough types provided: expected %d, got %d' %
             (len(types), len(srcargs)))
-        assert not srcvarargs, '*args not supported by enforceargs'
-        assert not srckeywords, '**kwargs not supported by enforceargs'
-        #
-        arglist = ', '.join(srcargs)
-        src = py.code.Source("""
-            def %(name)s(%(arglist)s):
-                if not we_are_translated():
-                    typecheck(%(arglist)s)    # pypy.rlib.objectmodel
-                return %(name)s_original(%(arglist)s)
-        """ % dict(name=f.func_name, arglist=arglist))
-        #
-        mydict = {f.func_name + '_original': f,
-                  'typecheck': typecheck,
-                  'we_are_translated': we_are_translated}
-        exec src.compile() in mydict
-        result = mydict[f.func_name]
-        result.func_defaults = f.func_defaults
-        result.func_dict.update(f.func_dict)
         result._annenforceargs_ = types
         return result
     return decorator
 
+
 # ____________________________________________________________
 
 class Symbolic(object):

pypy/rlib/test/test_jit.py

     assert driver.reds == ['a', 'b']
     assert driver.numreds == 2
 
+def test_jitdriver_inline():
+    driver = JitDriver(greens=[], reds='auto')
+    calls = []
+    def foo(a, b):
+        calls.append(('foo', a, b))
+
+    @driver.inline(foo)
+    def bar(a, b):
+        calls.append(('bar', a, b))
+        return a+b
+
+    assert bar._inline_jit_merge_point_ is foo
+    assert driver.inline_jit_merge_point
+    assert bar(40, 2) == 42
+    assert calls == [
+        ('foo', 40, 2),
+        ('bar', 40, 2),
+        ]
+
 def test_jitdriver_clone():
-    def foo():
-        pass
+    def bar(): pass
+    def foo(): pass
     driver = JitDriver(greens=[], reds=[])
-    py.test.raises(AssertionError, "driver.inline_in_portal(foo)")
+    py.test.raises(AssertionError, "driver.inline(bar)(foo)")
     #
     driver = JitDriver(greens=[], reds='auto')
     py.test.raises(AssertionError, "driver.clone()")
-    foo = driver.inline_in_portal(foo)
-    assert foo._inline_in_portal_ == True
+    foo = driver.inline(bar)(foo)
+    assert foo._inline_jit_merge_point_ == bar
     #
     driver.foo = 'bar'
     driver2 = driver.clone()

pypy/tool/sourcetools.py

     except AttributeError:
         firstlineno = -1
     return "(%s:%d)%s" % (mod or '?', firstlineno, name or 'UNKNOWN')
+
+
+def rpython_wrapper(f, template, templateargs=None, **globaldict):
+    """  
+    We cannot simply wrap the function using *args, **kwds, because it's not
+    RPython. Instead, we generate a function from ``template`` with exactly
+    the same argument list.
+    """
+    if templateargs is None:
+        templateargs = {}
+    srcargs, srcvarargs, srckeywords, defaults = inspect.getargspec(f)
+    assert not srcvarargs, '*args not supported by enforceargs'
+    assert not srckeywords, '**kwargs not supported by enforceargs'
+    #
+    arglist = ', '.join(srcargs)
+    templateargs.update(name=f.func_name,
+                        arglist=arglist,
+                        original=f.func_name+'_original')
+    src = template.format(**templateargs)
+    src = py.code.Source(src)
+    #
+    globaldict[f.func_name + '_original'] = f
+    exec src.compile() in globaldict
+    result = globaldict[f.func_name]
+    result.func_defaults = f.func_defaults
+    result.func_dict.update(f.func_dict)
+    return result

pypy/tool/test/test_sourcetools.py

-from pypy.tool.sourcetools import func_with_new_name, func_renamer
+from pypy.tool.sourcetools import func_with_new_name, func_renamer, rpython_wrapper
 
 def test_rename():
     def f(x, y=5):
     bar3 = func_with_new_name(bar, 'bar3')
     assert bar3.func_doc == 'new doc'
     assert bar2.func_doc != bar3.func_doc
+
+
+def test_rpython_wrapper():
+    calls = []
+
+    def bar(a, b):
+        calls.append(('bar', a, b))
+        return a+b
+
+    template = """
+        def {name}({arglist}):
+            calls.append(('decorated', {arglist}))
+            return {original}({arglist})
+    """
+    bar = rpython_wrapper(bar, template, calls=calls)
+    assert bar(40, 2) == 42
+    assert calls == [
+        ('decorated', 40, 2),
+        ('bar', 40, 2),
+        ]
+
+        

pypy/translator/backendopt/inline.py

     return (0.9999 * measure_median_execution_cost(graph) +
             count), True
 
-def inlinable_static_callers(graphs):
+def inlinable_static_callers(graphs, store_calls=False):
     ok_to_call = set(graphs)
     result = []
+    def add(parentgraph, block, op, graph):
+        if store_calls:
+            result.append((parentgraph, block, op, graph))
+        else:
+            result.append((parentgraph, graph))
+    #
     for parentgraph in graphs:
         for block in parentgraph.iterblocks():
             for op in block.operations:
                         if getattr(getattr(funcobj, '_callable', None),
                                    '_dont_inline_', False):
                             continue
-                        result.append((parentgraph, graph))
+                        add(parentgraph, block, op, graph)
                 if op.opname == "oosend":
                     meth = get_meth_from_oosend(op)
                     graph = getattr(meth, 'graph', None)
                     if graph is not None and graph in ok_to_call:
-                        result.append((parentgraph, graph))
+                        add(parentgraph, block, op, graph)
     return result
     
 def instrument_inline_candidates(graphs, threshold):