Maciej Fijalkowski avatar Maciej Fijalkowski committed f121b8a

Change set_param API. Now we run set_param(driver-or-None, 'name', value)
instead of driver.set_param(...). This makes it possible to pass None, which
changes the param on all drivers.

Comments (0)

Files changed (9)

pypy/jit/metainterp/test/test_ajit.py

 from pypy.rlib.jit import (JitDriver, we_are_jitted, hint, dont_look_inside,
     loop_invariant, elidable, promote, jit_debug, assert_green,
     AssertGreenFailed, unroll_safe, current_trace_length, look_inside_iff,
-    isconstant, isvirtual, promote_string)
+    isconstant, isvirtual, promote_string, set_param)
 from pypy.rlib.rarithmetic import ovfcheck
 from pypy.rpython.lltypesystem import lltype, llmemory, rffi
 from pypy.rpython.ootypesystem import ootype
                 n -= 1
                 x += n
             return x
-        def f(n, threshold):
-            myjitdriver.set_param('threshold', threshold)
+        def f(n, threshold, arg):
+            if arg:
+                set_param(myjitdriver, 'threshold', threshold)
+            else:
+                set_param(None, 'threshold', threshold)
             return g(n)
 
-        res = self.meta_interp(f, [10, 3])
+        res = self.meta_interp(f, [10, 3, 1])
         assert res == 9 + 8 + 7 + 6 + 5 + 4 + 3 + 2 + 1 + 0
         self.check_tree_loop_count(2)
 
-        res = self.meta_interp(f, [10, 13])
+        res = self.meta_interp(f, [10, 13, 0])
         assert res == 9 + 8 + 7 + 6 + 5 + 4 + 3 + 2 + 1 + 0
         self.check_tree_loop_count(0)
 
                                 get_printable_location=get_printable_location)
         bytecode = "0j10jc20a3"
         def f():
-            myjitdriver.set_param('threshold', 7)
-            myjitdriver.set_param('trace_eagerness', 1)
+            set_param(myjitdriver, 'threshold', 7)
+            set_param(myjitdriver, 'trace_eagerness', 1)
             i = j = c = a = 1
             while True:
                 myjitdriver.jit_merge_point(i=i, j=j, c=c, a=a)
         myjitdriver = JitDriver(greens = [], reds = ['n', 'i', 'sa', 'a'])
 
         def f(n, limit):
-            myjitdriver.set_param('retrace_limit', limit)
+            set_param(myjitdriver, 'retrace_limit', limit)
             sa = i = a = 0
             while i < n:
                 myjitdriver.jit_merge_point(n=n, i=i, sa=sa, a=a)
         myjitdriver = JitDriver(greens = [], reds = ['n', 'i', 'sa', 'a'])
 
         def f(n, limit):
-            myjitdriver.set_param('retrace_limit', 3)
-            myjitdriver.set_param('max_retrace_guards', limit)
+            set_param(myjitdriver, 'retrace_limit', 3)
+            set_param(myjitdriver, 'max_retrace_guards', limit)
             sa = i = a = 0
             while i < n:
                 myjitdriver.jit_merge_point(n=n, i=i, sa=sa, a=a)
         myjitdriver = JitDriver(greens = [], reds = ['n', 'i', 'sa', 'a',
                                                      'node'])
         def f(n, limit):
-            myjitdriver.set_param('retrace_limit', limit)
+            set_param(myjitdriver, 'retrace_limit', limit)
             sa = i = a = 0
             node = [1, 2, 3]
             node[1] = n
         myjitdriver = JitDriver(greens = ['pc'], reds = ['n', 'i', 'sa'])
         bytecode = "0+sI0+SI"
         def f(n):
-            myjitdriver.set_param('threshold', 3)
-            myjitdriver.set_param('trace_eagerness', 1)
-            myjitdriver.set_param('retrace_limit', 5)
-            myjitdriver.set_param('function_threshold', -1)
+            set_param(None, 'threshold', 3)
+            set_param(None, 'trace_eagerness', 1)
+            set_param(None, 'retrace_limit', 5)
+            set_param(None, 'function_threshold', -1)
             pc = sa = i = 0
             while pc < len(bytecode):
                 myjitdriver.jit_merge_point(pc=pc, n=n, sa=sa, i=i)
         myjitdriver = JitDriver(greens = ['pc'], reds = ['n', 'a', 'i', 'j', 'sa'])
         bytecode = "ij+Jj+JI"
         def f(n, a):
-            myjitdriver.set_param('threshold', 5)
-            myjitdriver.set_param('trace_eagerness', 1)
-            myjitdriver.set_param('retrace_limit', 2)
+            set_param(None, 'threshold', 5)
+            set_param(None, 'trace_eagerness', 1)
+            set_param(None, 'retrace_limit', 2)
             pc = sa = i = j = 0
             while pc < len(bytecode):
                 myjitdriver.jit_merge_point(pc=pc, n=n, sa=sa, i=i, j=j, a=a)
                 return B(self.val + 1)
         myjitdriver = JitDriver(greens = [], reds = ['sa', 'a'])
         def f():
-            myjitdriver.set_param('threshold', 3)
-            myjitdriver.set_param('trace_eagerness', 2)
+            set_param(None, 'threshold', 3)
+            set_param(None, 'trace_eagerness', 2)
             a = A(0)
             sa = 0
             while a.val < 8:
                 return B(self.val + 1)
         myjitdriver = JitDriver(greens = [], reds = ['sa', 'b', 'a'])
         def f(b):
-            myjitdriver.set_param('threshold', 6)
-            myjitdriver.set_param('trace_eagerness', 4)
+            set_param(None, 'threshold', 6)
+            set_param(None, 'trace_eagerness', 4)
             a = A(0)
             sa = 0
             while a.val < 15:
         myjitdriver = JitDriver(greens = ['pc'], reds = ['n', 'i', 'sa'])
         bytecode = "0+sI0+SI"
         def f(n):
-            myjitdriver.set_param('threshold', 3)
-            myjitdriver.set_param('trace_eagerness', 1)
-            myjitdriver.set_param('retrace_limit', 5)
-            myjitdriver.set_param('function_threshold', -1)
+            set_param(None, 'threshold', 3)
+            set_param(None, 'trace_eagerness', 1)
+            set_param(None, 'retrace_limit', 5)
+            set_param(None, 'function_threshold', -1)
             pc = sa = i = 0
             while pc < len(bytecode):
                 myjitdriver.jit_merge_point(pc=pc, n=n, sa=sa, i=i)

pypy/jit/metainterp/test/test_jitdriver.py

 """Tests for multiple JitDrivers."""
-from pypy.rlib.jit import JitDriver, unroll_safe
+from pypy.rlib.jit import JitDriver, unroll_safe, set_param
 from pypy.jit.metainterp.test.support import LLJitMixin, OOJitMixin
 from pypy.jit.metainterp.warmspot import get_stats
 
             return n
         #
         def loop2(g, r):
-            myjitdriver1.set_param('function_threshold', 0)
+            set_param(None, 'function_threshold', 0)
             while r > 0:
                 myjitdriver2.can_enter_jit(g=g, r=r)
                 myjitdriver2.jit_merge_point(g=g, r=r)

pypy/jit/metainterp/test/test_loop.py

 import py
-from pypy.rlib.jit import JitDriver, hint
+from pypy.rlib.jit import JitDriver, hint, set_param
 from pypy.rlib.objectmodel import compute_hash
 from pypy.jit.metainterp.warmspot import ll_meta_interp, get_stats
 from pypy.jit.metainterp.test.support import LLJitMixin, OOJitMixin
         myjitdriver = JitDriver(greens = ['pos'], reds = ['i', 'j', 'n', 'x'])
         bytecode = "IzJxji"
         def f(n, threshold):
-            myjitdriver.set_param('threshold', threshold)        
+            set_param(myjitdriver, 'threshold', threshold)        
             i = j = x = 0
             pos = 0
             op = '-'
         myjitdriver = JitDriver(greens = ['pos'], reds = ['i', 'j', 'n', 'x'])
         bytecode = "IzJxji"
         def f(nval, threshold):
-            myjitdriver.set_param('threshold', threshold)        
+            set_param(myjitdriver, 'threshold', threshold)        
             i, j, x = A(0), A(0), A(0)
             n = A(nval)
             pos = 0

pypy/jit/metainterp/test/test_recursive.py

 import py
-from pypy.rlib.jit import JitDriver, we_are_jitted, hint
+from pypy.rlib.jit import JitDriver, hint, set_param
 from pypy.rlib.jit import unroll_safe, dont_look_inside, promote
 from pypy.rlib.objectmodel import we_are_translated
 from pypy.rlib.debug import fatalerror
                 pc += 1
             return n
         def main(n):
-            myjitdriver.set_param('threshold', 3)
-            myjitdriver.set_param('trace_eagerness', 5)            
+            set_param(None, 'threshold', 3)
+            set_param(None, 'trace_eagerness', 5)            
             return f("c-l", n)
         expected = main(100)
         res = self.meta_interp(main, [100], enable_opts='', inline=True)
                 return recursive(n - 1) + 1
             return 0
         def loop(n):            
-            myjitdriver.set_param("threshold", 10)
+            set_param(myjitdriver, "threshold", 10)
             pc = 0
             while n:
                 myjitdriver.can_enter_jit(n=n)
             return 0
         myjitdriver = JitDriver(greens=[], reds=['n'])
         def loop(n):
-            myjitdriver.set_param("threshold", 4)
-            myjitdriver.set_param("trace_eagerness", 2)
+            set_param(None, "threshold", 4)
+            set_param(None, "trace_eagerness", 2)
             while n:
                 myjitdriver.can_enter_jit(n=n)
                 myjitdriver.jit_merge_point(n=n)
         TRACE_LIMIT = 66
  
         def main(inline):
-            myjitdriver.set_param("threshold", 10)
-            myjitdriver.set_param('function_threshold', 60)
+            set_param(None, "threshold", 10)
+            set_param(None, 'function_threshold', 60)
             if inline:
-                myjitdriver.set_param('inlining', True)
+                set_param(None, 'inlining', True)
             else:
-                myjitdriver.set_param('inlining', False)
+                set_param(None, 'inlining', False)
             return loop(100)
 
         res = self.meta_interp(main, [0], enable_opts='', trace_limit=TRACE_LIMIT)
                 pc += 1
             return n
         def g(m):
-            myjitdriver.set_param('inlining', True)
+            set_param(None, 'inlining', True)
             # carefully chosen threshold to make sure that the inner function
             # cannot be inlined, but the inner function on its own is small
             # enough
-            myjitdriver.set_param('trace_limit', 40)
+            set_param(None, 'trace_limit', 40)
             if m > 1000000:
                 f('', 0)
             result = 0
                     driver.can_enter_jit(c=c, i=i, v=v)
                 break
 
-        def main(c, i, set_param, v):
-            if set_param:
-                driver.set_param('function_threshold', 0)
+        def main(c, i, _set_param, v):
+            if _set_param:
+                set_param(driver, 'function_threshold', 0)
             portal(c, i, v)
 
         self.meta_interp(main, [10, 10, False, False], inline=True)

pypy/jit/metainterp/test/test_warmspot.py

 import py
-from pypy.jit.metainterp.warmspot import ll_meta_interp
 from pypy.jit.metainterp.warmspot import get_stats
-from pypy.rlib.jit import JitDriver
-from pypy.rlib.jit import unroll_safe
+from pypy.rlib.jit import JitDriver, set_param, unroll_safe
 from pypy.jit.backend.llgraph import runner
-from pypy.jit.metainterp.history import BoxInt
 
 from pypy.jit.metainterp.test.support import LLJitMixin, OOJitMixin
 from pypy.jit.metainterp.optimizeopt import ALL_OPTS_NAMES
                 n = A().m(n)
             return n
         def f(n, enable_opts):
-            myjitdriver.set_param('enable_opts', hlstr(enable_opts))
+            set_param(None, 'enable_opts', hlstr(enable_opts))
             return g(n)
 
         # check that the set_param will override the default

pypy/jit/metainterp/test/test_ztranslation.py

 import py
 from pypy.jit.metainterp.warmspot import rpython_ll_meta_interp, ll_meta_interp
 from pypy.jit.backend.llgraph import runner
-from pypy.rlib.jit import JitDriver, unroll_parameters
+from pypy.rlib.jit import JitDriver, unroll_parameters, set_param
 from pypy.rlib.jit import PARAMETERS, dont_look_inside, hint
 from pypy.jit.metainterp.jitprof import Profiler
 from pypy.rpython.lltypesystem import lltype, llmemory
                               get_printable_location=get_printable_location)
         def f(i):
             for param, defl in unroll_parameters:
-                jitdriver.set_param(param, defl)
-            jitdriver.set_param("threshold", 3)
-            jitdriver.set_param("trace_eagerness", 2)
+                set_param(jitdriver, param, defl)
+            set_param(jitdriver, "threshold", 3)
+            set_param(jitdriver, "trace_eagerness", 2)
             total = 0
             frame = Frame(i)
             while frame.l[0] > 3:
                 raise ValueError
             return 2
         def main(i):
-            jitdriver.set_param("threshold", 3)
-            jitdriver.set_param("trace_eagerness", 2)
+            set_param(jitdriver, "threshold", 3)
+            set_param(jitdriver, "trace_eagerness", 2)
             total = 0
             n = i
             while n > 3:

pypy/jit/metainterp/warmspot.py

                 op = block.operations[i]
                 if (op.opname == 'jit_marker' and
                     op.args[0].value == marker_name and
-                    op.args[1].value.active):   # the jitdriver
+                    (op.args[1].value is None or
+                    op.args[1].value.active)):   # the jitdriver
                     results.append((graph, block, i))
     return results
 
         _, PTR_SET_PARAM_STR_FUNCTYPE = self.cpu.ts.get_FuncType(
             [lltype.Ptr(STR)], lltype.Void)
         def make_closure(jd, fullfuncname, is_string):
-            state = jd.warmstate
-            def closure(i):
-                if is_string:
-                    i = hlstr(i)
-                getattr(state, fullfuncname)(i)
+            if jd is None:
+                def closure(i):
+                    if is_string:
+                        i = hlstr(i)
+                    for jd in self.jitdrivers_sd:
+                        getattr(jd.warmstate, fullfuncname)(i)
+            else:
+                state = jd.warmstate
+                def closure(i):
+                    if is_string:
+                        i = hlstr(i)
+                    getattr(state, fullfuncname)(i)
             if is_string:
                 TP = PTR_SET_PARAM_STR_FUNCTYPE
             else:
             return Constant(funcptr, TP)
         #
         for graph, block, i in find_set_param(graphs):
+            
             op = block.operations[i]
-            for jd in self.jitdrivers_sd:
-                if jd.jitdriver is op.args[1].value:
-                    break
+            if op.args[1].value is not None:
+                for jd in self.jitdrivers_sd:
+                    if jd.jitdriver is op.args[1].value:
+                        break
             else:
-                assert 0, "jitdriver of set_param() not found"
+                jd = None
             funcname = op.args[2].value
             key = jd, funcname
             if key not in closures:

pypy/module/pypyjit/interp_jit.py

 from pypy.tool.pairtype import extendabletype
 from pypy.rlib.rarithmetic import r_uint, intmask
 from pypy.rlib.jit import JitDriver, hint, we_are_jitted, dont_look_inside
+from pypy.rlib import jit
 from pypy.rlib.jit import current_trace_length, unroll_parameters
 import pypy.interpreter.pyopcode   # for side-effects
 from pypy.interpreter.error import OperationError, operationerrfmt
     if len(args_w) == 1:
         text = space.str_w(args_w[0])
         try:
-            pypyjitdriver.set_user_param(text)
+            jit.set_user_param(None, text)
         except ValueError:
             raise OperationError(space.w_ValueError,
                                  space.wrap("error in JIT parameters string"))
     for key, w_value in kwds_w.items():
         if key == 'enable_opts':
-            pypyjitdriver.set_param('enable_opts', space.str_w(w_value))
+            jit.set_param(None, 'enable_opts', space.str_w(w_value))
         else:
             intval = space.int_w(w_value)
             for name, _ in unroll_parameters:
                 if name == key and name != 'enable_opts':
-                    pypyjitdriver.set_param(name, intval)
+                    jit.set_param(None, name, intval)
                     break
             else:
                 raise operationerrfmt(space.w_TypeError,
         # special-cased by ExtRegistryEntry
         pass
 
-    def _set_param(self, name, value):
-        # special-cased by ExtRegistryEntry
-        # (internal, must receive a constant 'name')
-        # if value is DEFAULT, sets the default value.
-        assert name in PARAMETERS
-
-    @specialize.arg(0, 1)
-    def set_param(self, name, value):
-        """Set one of the tunable JIT parameter."""
-        self._set_param(name, value)
-
-    @specialize.arg(0, 1)
-    def set_param_to_default(self, name):
-        """Reset one of the tunable JIT parameters to its default value."""
-        self._set_param(name, DEFAULT)
-
-    def set_user_param(self, text):
-        """Set the tunable JIT parameters from a user-supplied string
-        following the format 'param=value,param=value', or 'off' to
-        disable the JIT.  For programmatic setting of parameters, use
-        directly JitDriver.set_param().
-        """
-        if text == 'off':
-            self.set_param('threshold', -1)
-            self.set_param('function_threshold', -1)
-            return
-        if text == 'default':
-            for name1, _ in unroll_parameters:
-                self.set_param_to_default(name1)
-            return
-        for s in text.split(','):
-            s = s.strip(' ')
-            parts = s.split('=')
-            if len(parts) != 2:
-                raise ValueError
-            name = parts[0]
-            value = parts[1]
-            if name == 'enable_opts':
-                self.set_param('enable_opts', value)
-            else:
-                for name1, _ in unroll_parameters:
-                    if name1 == name and name1 != 'enable_opts':
-                        try:
-                            self.set_param(name1, int(value))
-                        except ValueError:
-                            raise
-    set_user_param._annspecialcase_ = 'specialize:arg(0)'
-
-
     def on_compile(self, logger, looptoken, operations, type, *greenargs):
         """ A hook called when loop is compiled. Overwrite
         for your own jitdriver if you want to do something special, like
         self.jit_merge_point = self.jit_merge_point
         self.can_enter_jit = self.can_enter_jit
         self.loop_header = self.loop_header
-        self._set_param = self._set_param
-
         class Entry(ExtEnterLeaveMarker):
             _about_ = (self.jit_merge_point, self.can_enter_jit)
 
         class Entry(ExtLoopHeader):
             _about_ = self.loop_header
 
-        class Entry(ExtSetParam):
-            _about_ = self._set_param
+def _set_param(driver, name, value):
+    # special-cased by ExtRegistryEntry
+    # (internal, must receive a constant 'name')
+    # if value is DEFAULT, sets the default value.
+    assert name in PARAMETERS
+
+@specialize.arg(0, 1)
+def set_param(driver, name, value):
+    """Set one of the tunable JIT parameter. Driver can be None, then all
+    drivers have this set """
+    _set_param(driver, name, value)
+
+@specialize.arg(0, 1)
+def set_param_to_default(driver, name):
+    """Reset one of the tunable JIT parameters to its default value."""
+    _set_param(driver, name, DEFAULT)
+
+def set_user_param(driver, text):
+    """Set the tunable JIT parameters from a user-supplied string
+    following the format 'param=value,param=value', or 'off' to
+    disable the JIT.  For programmatic setting of parameters, use
+    directly JitDriver.set_param().
+    """
+    if text == 'off':
+        set_param(driver, 'threshold', -1)
+        set_param(driver, 'function_threshold', -1)
+        return
+    if text == 'default':
+        for name1, _ in unroll_parameters:
+            set_param_to_default(driver, name1)
+        return
+    for s in text.split(','):
+        s = s.strip(' ')
+        parts = s.split('=')
+        if len(parts) != 2:
+            raise ValueError
+        name = parts[0]
+        value = parts[1]
+        if name == 'enable_opts':
+            set_param(driver, 'enable_opts', value)
+        else:
+            for name1, _ in unroll_parameters:
+                if name1 == name and name1 != 'enable_opts':
+                    try:
+                        set_param(driver, name1, int(value))
+                    except ValueError:
+                        raise
+set_user_param._annspecialcase_ = 'specialize:arg(0)'
+
 
 # ____________________________________________________________
 #
                          resulttype=lltype.Void)
 
 class ExtSetParam(ExtRegistryEntry):
+    _about_ = _set_param
 
-    def compute_result_annotation(self, s_name, s_value):
+    def compute_result_annotation(self, s_driver, s_name, s_value):
         from pypy.annotation import model as annmodel
         assert s_name.is_constant()
         if not self.bookkeeper.immutablevalue(DEFAULT).contains(s_value):
         from pypy.objspace.flow.model import Constant
 
         hop.exception_cannot_occur()
-        driver = self.instance.im_self
-        name = hop.args_s[0].const
+        driver = hop.inputarg(lltype.Void, arg=0)
+        name = hop.args_s[1].const
         if name == 'enable_opts':
             repr = string_repr
         else:
             repr = lltype.Signed
-        if (isinstance(hop.args_v[1], Constant) and
-            hop.args_v[1].value is DEFAULT):
+        if (isinstance(hop.args_v[2], Constant) and
+            hop.args_v[2].value is DEFAULT):
             value = PARAMETERS[name]
             v_value = hop.inputconst(repr, value)
         else:
-            v_value = hop.inputarg(repr, arg=1)
+            v_value = hop.inputarg(repr, arg=2)
         vlist = [hop.inputconst(lltype.Void, "set_param"),
-                 hop.inputconst(lltype.Void, driver),
+                 driver,
                  hop.inputconst(lltype.Void, name),
                  v_value]
         return hop.genop('jit_marker', vlist,
                          resulttype=lltype.Void)
+
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.