Commits

Alex Gaynor committed a7c9fe6

(fijal, alex, greg): When you have 2 jitdrivers, if the first one gets into a function of the second one, but the trace gets too long, don't find a function inside the second jitdriver as the longest function.

Comments (0)

Files changed (4)

pypy/jit/metainterp/history.py

 
     def forget_value(self):
         self.value = 0
-        
+
     def clonebox(self):
         return BoxInt(self.value)
 
     def add_new_loop(self, loop):
         pass
 
+    def record_aborted(self, greenkey):
+        pass
+
     def view(self, **kwds):
         pass
 
     def __init__(self):
         self.loops = []
         self.locations = []
+        self.aborted_keys = []
 
     def set_history(self, history):
         self.history = history
     def add_new_loop(self, loop):
         self.loops.append(loop)
 
+    def record_aborted(self, greenkey):
+        self.aborted_keys.append(greenkey)
+
     # test read interface
 
     def get_all_loops(self):

pypy/jit/metainterp/pyjitpl.py

                                jcposition, redboxes):
         resumedescr = compile.ResumeAtPositionDescr()
         self.capture_resumedata(resumedescr, orgpc)
-        
+
         any_operation = len(self.metainterp.history.operations) > 0
         jitdriver_sd = self.metainterp.staticdata.jitdrivers_sd[jdindex]
         self.verify_green_args(jitdriver_sd, greenboxes)
             "found a loop_header for a JitDriver that does not match "
             "the following jit_merge_point's")
         self.metainterp.seen_loop_header_for_jdindex = -1
-        
+
         #
         if not self.metainterp.in_recursion:
             assert jitdriver_sd is self.metainterp.jitdriver_sd
         f.setup_call(boxes)
         raise ChangeFrame
 
+    def is_main_jitcode(self, jitcode):
+        return self.jitdriver_sd is not None and jitcode is self.jitdriver_sd.mainjitcode
+
     def newframe(self, jitcode, greenkey=None):
         if jitcode.is_portal:
             self.in_recursion += 1
-        if greenkey is not None:
+        if greenkey is not None and self.is_main_jitcode(jitcode):
             self.portal_trace_positions.append(
                     (greenkey, len(self.history.operations)))
         if len(self.free_frames_list) > 0:
 
     def popframe(self):
         frame = self.framestack.pop()
-        if frame.jitcode.is_portal:
+        jitcode = frame.jitcode
+        if jitcode.is_portal:
             self.in_recursion -= 1
-        if frame.greenkey is not None:
+        if frame.greenkey is not None and self.is_main_jitcode(jitcode):
             self.portal_trace_positions.append(
                     (None, len(self.history.operations)))
         # we save the freed MIFrames to avoid needing to re-create new
         warmrunnerstate = self.jitdriver_sd.warmstate
         if len(self.history.operations) > warmrunnerstate.trace_limit:
             greenkey_of_huge_function = self.find_biggest_function()
+            self.staticdata.stats.record_aborted(greenkey_of_huge_function)
             self.portal_trace_positions = None
             if greenkey_of_huge_function is not None:
                 warmrunnerstate.disable_noninlinable_function(
             dont_change_position = True
         else:
             dont_change_position = False
-        try:            
+        try:
             self.prepare_resume_from_failure(key.guard_opnum, dont_change_position)
             if self.resumekey_original_loop_token is None:   # very rare case
                 raise SwitchToBlackhole(ABORT_BRIDGE)
 
         self.history.inputargs = original_inputargs
         self.history.operations = self.history.operations[:start]
-        
+
         self.history.record(rop.JUMP, bridge_arg_boxes[num_green_args:], None)
         try:
             target_loop_token = compile.compile_new_bridge(self,

pypy/jit/metainterp/test/test_jitdriver.py

 """Tests for multiple JitDrivers."""
-from pypy.rlib.jit import JitDriver
+from pypy.rlib.jit import JitDriver, unroll_safe
 from pypy.jit.metainterp.test.test_basic import LLJitMixin, OOJitMixin
+from pypy.jit.metainterp.warmspot import get_stats
 
 
 def getloc1():
         # we expect no int_sub, but a residual call
         self.check_loops(int_sub=0, call=1)
 
+    def test_multiple_jits_trace_too_long(self):
+        myjitdriver1 = JitDriver(greens=["n"], reds=["i", "box"])
+        myjitdriver2 = JitDriver(greens=["n"], reds=["i"])
+
+        class IntBox(object):
+            def __init__(self, val):
+                self.val = val
+
+        def loop1(n):
+            i = 0
+            box = IntBox(10)
+            while i < n:
+                myjitdriver1.can_enter_jit(n=n, i=i, box=box)
+                myjitdriver1.jit_merge_point(n=n, i=i, box=box)
+                i += 1
+                loop2(box)
+            return i
+
+        def loop2(n):
+            i = 0
+            f(10)
+            while i < n.val:
+                myjitdriver2.can_enter_jit(n=n, i=i)
+                myjitdriver2.jit_merge_point(n=n, i=i)
+                i += 1
+
+        @unroll_safe
+        def f(n):
+            i = 0
+            while i < n:
+                i += 1
+
+        res = self.meta_interp(loop1, [10], inline=True, trace_limit=6)
+        assert res == 10
+        stats = get_stats()
+        assert stats.aborted_keys == [None, None]
+
 
 class TestLLtype(MultipleJitDriversTests, LLJitMixin):
     pass

pypy/jit/metainterp/test/test_pyjitpl.py

     class FakeStaticData:
         cpu = None
         warmrunnerdesc = None
+        mainjitcode = portal
 
-    metainterp = pyjitpl.MetaInterp(FakeStaticData(), None)
+    metainterp = pyjitpl.MetaInterp(FakeStaticData(), FakeStaticData())
     metainterp.framestack = []
     class FakeHistory:
         operations = []