Commits

Antonio Cuni committed 0104863

implement @jitdriver.inline() and test that it's correctly recognized by warmspot

Comments (0)

Files changed (4)

pypy/jit/metainterp/test/test_warmspot.py

         # test that the machinery to inline jit_merge_points in callers
         # works. The final user does not need to mess manually with the
         # _inline_jit_merge_point_ attribute and similar, it is all nicely
-        # handled by @JitDriver.inline()
+        # handled by @JitDriver.inline() (see next tests)
         myjitdriver = JitDriver(greens = ['a'], reds = 'auto')
 
         def jit_merge_point(a, b):
         self.check_resops(int_add=4)
 
 
-    def test_inline_in_portal(self):
-        py.test.skip('in-progress')
+    def test_jitdriver_inline(self):
         myjitdriver = JitDriver(greens = [], reds = 'auto')
         class MyRange(object):
             def __init__(self, n):
             def __iter__(self):
                 return self
 
-            @myjitdriver.inline_in_portal
+            def jit_merge_point(self):
+                myjitdriver.jit_merge_point()
+
+            @myjitdriver.inline(jit_merge_point)
             def next(self):
-                myjitdriver.jit_merge_point()
                 if self.cur == self.n:
                     raise StopIteration
                 self.cur += 1
                 return self.cur
 
-        def one():
+        def f(n):
             res = 0
-            for i in MyRange(10):
+            for i in MyRange(n):
                 res += i
             return res
 
-        def two():
-            res = 0
-            for i in MyRange(13):
-                res += i * 2
-            return res
-
-        def f(n, m):
-            res = one() * 100
-            res += two()
-            return res
-        expected = f(21, 5)
-        res = self.meta_interp(f, [21, 5])
+        expected = f(21)
+        res = self.meta_interp(f, [21])
         assert res == expected
-        self.check_resops(int_eq=4, int_add=8)
-        self.check_trace_count(2)
+        self.check_resops(int_eq=2, int_add=4)
+        self.check_trace_count(1)
 
 
 class TestLLWarmspot(WarmspotTests, LLJitMixin):
 from pypy.rlib.objectmodel import CDefinedIntSymbolic, keepalive_until_here, specialize
 from pypy.rlib.unroll import unrolling_iterable
 from pypy.rpython.extregistry import ExtRegistryEntry
+from pypy.tool.sourcetools import rpython_wrapper
 
 DEBUG_ELIDABLE_FUNCTIONS = False
 
     active = True          # if set to False, this JitDriver is ignored
     virtualizables = []
     name = 'jitdriver'
-    inlined_in_portal = False
+    inline_jit_merge_point = False
 
     def __init__(self, greens=None, reds=None, virtualizables=None,
                  get_jitcell_at=None, set_jitcell_at=None,
         # special-cased by ExtRegistryEntry
         pass
 
-    def inline_in_portal(self, func):
-        assert self.autoreds, "inline_in_portal works only with reds='auto'"
-        func._inline_in_portal_ = True
-        func._always_inline_ = True
-        self.inlined_in_portal = True
-        return func
+    def inline(self, call_jit_merge_point):
+        assert self.autoreds, "@inline works only with reds='auto'"
+        self.inline_jit_merge_point = True
+        def decorate(func):
+            template = """
+                def {name}({arglist}):
+                    {call_jit_merge_point}({arglist})
+                    return {original}({arglist})
+            """
+            templateargs = {'call_jit_merge_point': call_jit_merge_point.__name__}
+            globaldict = {call_jit_merge_point.__name__: call_jit_merge_point}
+            result = rpython_wrapper(func, template, templateargs, **globaldict)
+            result._inline_jit_merge_point_ = call_jit_merge_point
+            return result
+
+        return decorate
+        
 
     def clone(self):
         assert self.inline_jit_merge_point, 'JitDriver.clone works only after @inline'

pypy/rlib/test/test_jit.py

     assert driver.reds == ['a', 'b']
     assert driver.numreds == 2
 
+def test_jitdriver_inline():
+    driver = JitDriver(greens=[], reds='auto')
+    calls = []
+    def foo(a, b):
+        calls.append(('foo', a, b))
+
+    @driver.inline(foo)
+    def bar(a, b):
+        calls.append(('bar', a, b))
+        return a+b
+
+    assert bar._inline_jit_merge_point_ is foo
+    assert driver.inline_jit_merge_point
+    assert bar(40, 2) == 42
+    assert calls == [
+        ('foo', 40, 2),
+        ('bar', 40, 2),
+        ]
+
 def test_jitdriver_clone():
-    def foo():
-        pass
+    def bar(): pass
+    def foo(): pass
     driver = JitDriver(greens=[], reds=[])
-    py.test.raises(AssertionError, "driver.inline_in_portal(foo)")
+    py.test.raises(AssertionError, "driver.inline(bar)(foo)")
     #
     driver = JitDriver(greens=[], reds='auto')
     py.test.raises(AssertionError, "driver.clone()")
-    foo = driver.inline_in_portal(foo)
-    assert foo._inline_in_portal_ == True
-    assert foo._always_inline_ == True
+    foo = driver.inline(bar)(foo)
+    assert foo._inline_jit_merge_point_ == bar
     #
     driver.foo = 'bar'
     driver2 = driver.clone()

pypy/tool/sourcetools.py

     return "(%s:%d)%s" % (mod or '?', firstlineno, name or 'UNKNOWN')
 
 
-def rpython_wrapper(f, template, **globaldict):
+def rpython_wrapper(f, template, templateargs=None, **globaldict):
     """  
     We cannot simply wrap the function using *args, **kwds, because it's not
     RPython. Instead, we generate a function from ``template`` with exactly
     the same argument list.
     """
+    if templateargs is None:
+        templateargs = {}
     srcargs, srcvarargs, srckeywords, defaults = inspect.getargspec(f)
     assert not srcvarargs, '*args not supported by enforceargs'
     assert not srckeywords, '**kwargs not supported by enforceargs'
     #
     arglist = ', '.join(srcargs)
-    src = template.format(name=f.func_name, arglist=arglist,
-                          original=f.func_name+'_original')
+    templateargs.update(name=f.func_name,
+                        arglist=arglist,
+                        original=f.func_name+'_original')
+    src = template.format(**templateargs)
     src = py.code.Source(src)
     #
     globaldict[f.func_name + '_original'] = f
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.