1. cgerum
  2. pypy

Commits

Alex Gaynor  committed 184d747

Implement this mostly. unroll_if does simple python code gen, and isconstant has an oopspec so it's recognized by jtransform, which creates a spaceop for it. Pretty simple all in all.

  • Participants
  • Parent commits 6159c9a
  • Branches unroll-if-alt

Comments (0)

Files changed (5)

File pypy/jit/codewriter/jtransform.py

View file
  • Ignore whitespace
             return SpaceOperation('%s_assert_green' % kind, args, None)
         elif oopspec_name == 'jit.current_trace_length':
             return SpaceOperation('current_trace_length', [], op.result)
+        elif oopspec_name == 'jit.isconstant':
+            kind = getkind(args[0].concretetype)
+            return SpaceOperation('%s_isconstant' % kind, args, op.result)
         else:
             raise AssertionError("missing support for %r" % oopspec_name)
 

File pypy/jit/metainterp/blackhole.py

View file
  • Ignore whitespace
     def bhimpl_current_trace_length():
         return -1
 
+    @arguments("i", returns="i")
+    def bhimpl_int_isconstant(x):
+        return False
+
     # ----------
     # the main hints and recursive calls
 

File pypy/jit/metainterp/pyjitpl.py

View file
  • Ignore whitespace
         return ConstInt(trace_length)
 
     @arguments("box")
+    def _opimpl_isconstant(self, box):
+        return ConstInt(isinstance(box, Const))
+
+    opimpl_int_isconstant = _opimpl_isconstant
+
+    @arguments("box")
     def opimpl_virtual_ref(self, box):
         # Details on the content of metainterp.virtualref_boxes:
         #

File pypy/jit/metainterp/test/test_ajit.py

View file
  • Ignore whitespace
+import sys
+
 import py
-import sys
-from pypy.rlib.jit import JitDriver, we_are_jitted, hint, dont_look_inside
-from pypy.rlib.jit import loop_invariant, elidable, promote
-from pypy.rlib.jit import jit_debug, assert_green, AssertGreenFailed
-from pypy.rlib.jit import unroll_safe, current_trace_length
+
+from pypy import conftest
+from pypy.jit.codewriter.policy import JitPolicy, StopAtXPolicy
 from pypy.jit.metainterp import pyjitpl, history
+from pypy.jit.metainterp.optimizeopt import ALL_OPTS_DICT
+from pypy.jit.metainterp.test.support import LLJitMixin, OOJitMixin, noConst
+from pypy.jit.metainterp.typesystem import LLTypeHelper, OOTypeHelper
+from pypy.jit.metainterp.warmspot import get_stats
 from pypy.jit.metainterp.warmstate import set_future_value
-from pypy.jit.metainterp.warmspot import get_stats
-from pypy.jit.codewriter.policy import JitPolicy, StopAtXPolicy
-from pypy import conftest
+from pypy.rlib.jit import (JitDriver, we_are_jitted, hint, dont_look_inside,
+    loop_invariant, elidable, promote, jit_debug, assert_green,
+    AssertGreenFailed, unroll_safe, current_trace_length, unroll_if, isconstant)
 from pypy.rlib.rarithmetic import ovfcheck
-from pypy.jit.metainterp.typesystem import LLTypeHelper, OOTypeHelper
 from pypy.rpython.lltypesystem import lltype, llmemory, rffi
 from pypy.rpython.ootypesystem import ootype
-from pypy.jit.metainterp.optimizeopt import ALL_OPTS_DICT
-from pypy.jit.metainterp.test.support import LLJitMixin, OOJitMixin, noConst
+
 
 class BasicTests:
-
     def test_basic(self):
         def f(x, y):
             return x + y
 
         self.meta_interp(f, [10], repeat=3)
 
+    def test_unroll_if_const(self):
+        @unroll_if(lambda arg: isconstant(arg))
+        def f(arg):
+            s = 0
+            while arg > 0:
+                s += arg
+                arg -= 1
+            return s
+
+        driver = JitDriver(greens = ['code'], reds = ['n', 'arg', 's'])
+
+        def main(code, n, arg):
+            s = 0
+            while n > 0:
+                driver.jit_merge_point(code=code, n=n, arg=arg, s=s)
+                if code == 0:
+                    s += f(arg)
+                else:
+                    s += f(1)
+                n -= 1
+            return s
+
+        res = self.meta_interp(main, [0, 10, 2], enable_opts='')
+        assert res == main(0, 10, 2)
+        self.check_loops(call=1)
+        res = self.meta_interp(main, [1, 10, 2], enable_opts='')
+        assert res == main(1, 10, 2)
+        self.check_loops(call=0)
+
+
 class TestLLtype(BaseLLtypeTests, LLJitMixin):
     pass

File pypy/rlib/jit.py

View file
  • Ignore whitespace
+import sys
+
 import py
-import sys
+
+from pypy.rlib.nonconst import NonConstant
+from pypy.rlib.objectmodel import CDefinedIntSymbolic, keepalive_until_here, specialize
+from pypy.rlib.unroll import unrolling_iterable
 from pypy.rpython.extregistry import ExtRegistryEntry
-from pypy.rlib.objectmodel import CDefinedIntSymbolic
-from pypy.rlib.objectmodel import keepalive_until_here, specialize
-from pypy.rlib.unroll import unrolling_iterable
-from pypy.rlib.nonconst import NonConstant
+from pypy.tool.sourcetools import func_with_new_name
+
 
 def elidable(func):
     """ Decorate a function as "trace-elidable". This means precisely that:
     func._jit_loop_invariant_ = True
     return func
 
+def _get_args(func):
+    import inspect
+
+    args, varargs, varkw, defaults = inspect.getargspec(func)
+    args = ["v%s" % (i, ) for i in range(len(args))]
+    assert varargs is None and varkw is None
+    assert not defaults
+    return args
+
 def elidable_promote(promote_args='all'):
     """ A decorator that promotes all arguments and then calls the supplied
     function
     """
     def decorator(func):
-        import inspect
         elidable(func)
-        args, varargs, varkw, defaults = inspect.getargspec(func)
-        args = ["v%s" % (i, ) for i in range(len(args))]
-        assert varargs is None and varkw is None
-        assert not defaults
+        args = _get_args(func)
         argstring = ", ".join(args)
         code = ["def f(%s):\n" % (argstring, )]
         if promote_args != 'all':
     warnings.warn("purefunction_promote is deprecated, use elidable_promote instead", DeprecationWarning)
     return elidable_promote(*args, **kwargs)
 
+def unroll_if(predicate):
+    def inner(func):
+        args = _get_args(func)
+        argstring = ", ".join(args)
+        d = {
+            "func": func,
+            "func_unroll": unroll_safe(func_with_new_name(func, func.__name__ + "_unroll")),
+            "predicate": predicate,
+        }
+        exec py.code.Source("""
+            def f(%(argstring)s):
+                if predicate(%(argstring)s):
+                    return func_unroll(%(argstring)s)
+                else:
+                    return func(%(argstring)s)
+        """ % {"argstring": argstring}).compile() in d
+        result = d["f"]
+        result.func_name = func.func_name + "_unroll_if"
+        return result
+    return inner
 
 def oopspec(spec):
     def decorator(func):
         return func
     return decorator
 
+@oopspec("jit.isconstant(value)")
+def isconstant(value):
+    """
+    While tracing, returns whether or not the value is currently known to be
+    constant. This is not perfect, values can become constant later. Mostly for
+    use with @unroll_if.
+    """
+    # I hate the annotator so much.
+    if NonConstant(False):
+        return True
+    return False
+
 class Entry(ExtRegistryEntry):
     _about_ = hint
 
 
     def specialize_call(self, hop):
         pass
-    
+
 vref_None = non_virtual_ref(None)
 
 # ____________________________________________________________
     """Inconsistency in the JIT hints."""
 
 PARAMETERS = {'threshold': 1032, # just above 1024
-              'function_threshold': 1617, # slightly more than one above 
+              'function_threshold': 1617, # slightly more than one above
               'trace_eagerness': 200,
               'trace_limit': 12000,
               'inlining': 1,
                             raise
     set_user_param._annspecialcase_ = 'specialize:arg(0)'
 
-    
+
     def on_compile(self, logger, looptoken, operations, type, *greenargs):
         """ A hook called when loop is compiled. Overwrite
         for your own jitdriver if you want to do something special, like