rxe avatar rxe committed b538c2f

(cfbolz, rxe) Add a check so that there shouldn't be any code between can_enter_jit() and jit_merge_point().

Comments (0)

Files changed (5)

rpython/jit/metainterp/test/test_ajit.py

                 if i >= len(bytecode):
                     break
                 op = bytecode[i]
+                i += 1
                 if op == 'j':
                     j += 1
                 elif op == 'c':
 
                 else:
                     return ord(op)
-                i += 1
             return 42
         assert f() == 42
         def g():

rpython/jit/metainterp/test/test_virtualizable.py

             frame = Frame(n, 0)
             somewhere_else.top_frame = frame        # escapes
             frame = hint(frame, access_directly=True)
-            while frame.x > 0:
+            while True:
                 myjitdriver.jit_merge_point(frame=frame, fail=fail)
+                if frame.x <= 0:
+                    break
                 frame.x -= 1
                 if fail or frame.x > 2:
                     frame.y += frame.x

rpython/jit/tl/tl.py

 
         stack = hint(stack, access_directly=True)
 
-        while pc < len(code):
+        while True:
             myjitdriver.jit_merge_point(pc=pc, code=code,
                                         stack=stack, inputarg=inputarg)
+
+            if pc >= len(code):
+                break
+
             opcode = ord(code[pc])
             stack.stackpos = promote(stack.stackpos)
             pc += 1

rpython/rlib/jit.py

     virtualizables = []
     name = 'jitdriver'
     inline_jit_merge_point = False
+    _store_last_enter_jit = None
 
     def __init__(self, greens=None, reds=None, virtualizables=None,
                  get_jitcell_at=None, set_jitcell_at=None,
     def _freeze_(self):
         return True
 
-    def _check_arguments(self, livevars):
+    def _check_arguments(self, livevars, is_merge_point):
         assert set(livevars) == self._somelivevars
         # check heuristically that 'reds' and 'greens' are ordered as
         # the JIT will need them to be: first INTs, then REFs, then
         # FLOATs.
         if len(self._heuristic_order) < len(livevars):
             from rpython.rlib.rarithmetic import (r_singlefloat, r_longlong,
-                                               r_ulonglong, r_uint)
+                                                  r_ulonglong, r_uint)
             added = False
             for var, value in livevars.items():
                 if var not in self._heuristic_order:
                         "must be INTs, REFs, FLOATs; got %r" %
                         (color, allkinds))
 
+        if is_merge_point:
+            if self._store_last_enter_jit:
+                if livevars != self._store_last_enter_jit:
+                    raise JitHintError(
+                        "Bad can_enter_jit() placement: there should *not* "
+                        "be any code in between can_enter_jit() -> jit_merge_point()" )
+                self._store_last_enter_jit = None
+        else:
+            self._store_last_enter_jit = livevars
+
     def jit_merge_point(_self, **livevars):
         # special-cased by ExtRegistryEntry
         if _self.check_untranslated:
-            _self._check_arguments(livevars)
+            _self._check_arguments(livevars, True)
 
     def can_enter_jit(_self, **livevars):
         if _self.autoreds:
             raise TypeError, "Cannot call can_enter_jit on a driver with reds='auto'"
         # special-cased by ExtRegistryEntry
         if _self.check_untranslated:
-            _self._check_arguments(livevars)
+            _self._check_arguments(livevars, False)
 
     def loop_header(self):
         # special-cased by ExtRegistryEntry

rpython/rlib/test/test_jit.py

     assert driver2.foo == 'bar'
 
 
+def test_merge_enter_different():
+    myjitdriver = JitDriver(greens=[], reds=['n'])
+    def fn(n):
+        while n > 0:
+            myjitdriver.jit_merge_point(n=n)
+            myjitdriver.can_enter_jit(n=n)
+            n -= 1
+        return n
+    py.test.raises(JitHintError, fn, 100)
+
+    myjitdriver = JitDriver(greens=['n'], reds=[])
+    py.test.raises(JitHintError, fn, 100)
+
 class TestJIT(BaseRtypingTest):
     def test_hint(self):
         def f():
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.