Commits

Armin Rigo committed c1461e8

Implement stack checks at the beginning of the entry code for
call_assembler. A test passes on 32-bit but still crashes on 64-bit...
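
For orientation: the entry code emitted below by _call_header_with_stack_check()
boils the whole check down to one unsigned comparison.  A minimal Python sketch
of the condition on a 32-bit word (the helper name stack_fits is made up for
illustration):

    def stack_fits(esp, startaddr, length, word_mask=(1 << 32) - 1):
        # same condition as the emitted MOV eax, esp / SUB eax, [startaddr] /
        # CMP eax, length / JB .skip: when esp < startaddr the subtraction
        # wraps around to a huge unsigned value, so a single compare rejects
        # both "below the start" and "past the end" of the stack region
        return ((esp - startaddr) & word_mask) < length

    assert stack_fits(esp=0x5000, startaddr=0x4000, length=0x2000)      # inside
    assert not stack_fits(esp=0x3000, startaddr=0x4000, length=0x2000)  # below
    assert not stack_fits(esp=0x7000, startaddr=0x4000, length=0x2000)  # beyond

When the check fails, the entry code calls a slow path that re-checks precisely
and raises; when it passes, the cost is four instructions and no call.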


Files changed (8)

pypy/jit/backend/llsupport/llmodel.py

         self.pos_exception = pos_exception
         self.pos_exc_value = pos_exc_value
         self.save_exception = save_exception
+        self.insert_stack_check = lambda: (0, 0, 0)
 
 
     def _setup_exception_handling_translated(self):
             # in the assignment to self.saved_exc_value, as needed.
             self.saved_exc_value = exc_value
 
+        from pypy.rlib import rstack
+        STACK_CHECK_SLOWPATH = lltype.Ptr(lltype.FuncType([lltype.Signed],
+                                                          lltype.Void))
+        def insert_stack_check():
+            startaddr = rstack._stack_get_start_adr()
+            length = rstack._stack_get_length()
+            f = llhelper(STACK_CHECK_SLOWPATH, rstack.stack_check_slowpath)
+            slowpathaddr = rffi.cast(lltype.Signed, f)
+            return startaddr, length, slowpathaddr
+
         self.pos_exception = pos_exception
         self.pos_exc_value = pos_exc_value
         self.save_exception = save_exception
+        self.insert_stack_check = insert_stack_check
 
     def _setup_on_leave_jitted_untranslated(self):
         # assume we don't need a backend leave in this case
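
The triple returned here is the entire interface between the runtime and the
backend; the untranslated setup above stubs it out with lambda: (0, 0, 0).  A
sketch of the contract, under the assumption (confirmed by the backend code
below) that a slowpathaddr of 0 means "emit no check"; emit is a stand-in name:

    def emit_stack_check(insert_stack_check, emit):
        startaddr, length, slowpathaddr = insert_stack_check()
        if slowpathaddr == 0:
            return         # untranslated: no stack limits known, emit nothing
        emit(startaddr, length, slowpathaddr)  # translated: CMP/JB/CALL code

    emit_stack_check(lambda: (0, 0, 0), emit=None)  # no-op, emit never called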

pypy/jit/backend/model.py

     done_with_this_frame_int_v = -1
     done_with_this_frame_ref_v = -1
     done_with_this_frame_float_v = -1
+    exit_frame_with_exception_v = -1
     total_compiled_loops = 0
     total_compiled_bridges = 0
     total_freed_loops = 0

pypy/jit/backend/x86/assembler.py

         self.fail_boxes_count = 0
         self._current_depths_cache = (0, 0)
         self.datablockwrapper = None
+        self.stack_check_slowpath_imm = imm0
         self.teardown()
 
     def leave_jitted_hook(self):
             self._build_float_constants()
         if hasattr(gc_ll_descr, 'get_malloc_fixedsize_slowpath_addr'):
             self._build_malloc_fixedsize_slowpath()
+        self._build_stack_check_slowpath()
         debug_start('jit-backend-counts')
         self.set_debug(have_debug_prints())
         debug_stop('jit-backend-counts')
         rawstart = mc.materialize(self.cpu.asmmemmgr, [])
         self.malloc_fixedsize_slowpath2 = rawstart
 
+    _STACK_CHECK_SLOWPATH = lltype.Ptr(lltype.FuncType([lltype.Signed],
+                                                       lltype.Void))
+    def _build_stack_check_slowpath(self):
+        from pypy.rlib import rstack
+        mc = codebuf.MachineCodeBlockWrapper()
+        mc.PUSH_r(ebp.value)
+        mc.MOV_rr(ebp.value, esp.value)
+        #
+        if IS_X86_64:
+            # on the x86_64, we have to save all the registers that may
+            # have been used to pass arguments
+            for reg in [edi, esi, edx, ecx, r8, r9]:
+                mc.PUSH_r(reg.value)
+            mc.SUB_ri(esp.value, 8*8)
+            for i in range(8):
+                mc.MOVSD_sx(8*i, i)     # xmm0 to xmm7
+        #
+        if IS_X86_32:
+            mc.LEA_rb(eax.value, +8)    # load the caller's esp, i.e. the
+            mc.PUSH_r(eax.value)        # current stack top, as the argument
+        elif IS_X86_64:
+            mc.LEA_rb(edi.value, +16)   # ditto, as the first argument (edi)
+            mc.AND_ri(esp.value, -16)   # align the stack for the call
+        #
+        f = llhelper(self._STACK_CHECK_SLOWPATH, rstack.stack_check_slowpath)
+        addr = rffi.cast(lltype.Signed, f)
+        mc.CALL(imm(addr))
+        #
+        # if the slowpath left an exception set, jump to the exit code below
+        mc.MOV(eax, heap(self.cpu.pos_exception()))
+        mc.TEST_rr(eax.value, eax.value)
+        mc.J_il8(rx86.Conditions['NZ'], 0)
+        jnz_location = mc.get_relative_pos()
+        #
+        if IS_X86_64:
+            # restore the registers
+            # (note: if the AND_ri above actually changed esp, these
+            # esp-relative loads read other slots than MOVSD_sx filled)
+            for i in range(7, -1, -1):
+                mc.MOVSD_xs(i, 8*i)
+            for i, reg in [(6, r9), (5, r8), (4, ecx),
+                           (3, edx), (2, esi), (1, edi)]:
+                mc.MOV_rb(reg.value, -8*i)
+        #
+        mc.MOV_rr(esp.value, ebp.value)
+        mc.POP_r(ebp.value)
+        mc.RET()
+        #
+        # patch the JNZ above
+        offset = mc.get_relative_pos() - jnz_location
+        assert 0 < offset <= 127
+        mc.overwrite(jnz_location-1, chr(offset))
+        # fetch the exception instance and clear it from the global position
+        mc.MOV(eax, heap(self.cpu.pos_exc_value()))
+        mc.MOV(heap(self.cpu.pos_exception()), imm0)
+        mc.MOV(heap(self.cpu.pos_exc_value()), imm0)
+        # save the current exception instance into fail_boxes_ptr[0]
+        adr = self.fail_boxes_ptr.get_addr_for_num(0)
+        mc.MOV(heap(adr), eax)
+        # call the helper function to set the GC flag on the fail_boxes_ptr
+        # array (note that there is no exception any more here)
+        addr = self.cpu.get_on_leave_jitted_int(save_exception=False)
+        mc.CALL(imm(addr))
+        #
+        assert self.cpu.exit_frame_with_exception_v >= 0
+        mc.MOV_ri(eax.value, self.cpu.exit_frame_with_exception_v)
+        #
+        # footer -- note the ADD, which skips the return address of this
+        # function, and will instead return to the caller's caller.  Note
+        # also that we completely ignore the saved arguments, because we
+        # are interrupting the function.
+        mc.MOV_rr(esp.value, ebp.value)
+        mc.POP_r(ebp.value)
+        mc.ADD_ri(esp.value, WORD)
+        mc.RET()
+        #
+        rawstart = mc.materialize(self.cpu.asmmemmgr, [])
+        self.stack_check_slowpath_imm = imm(rawstart)
+
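
The register juggling above follows the SysV AMD64 calling convention: the slow
path runs before the JIT frame has consumed its arguments, so every register
that may carry one has to survive the call.  For reference, a plain restatement
of the ABI rather than code from this diff:

    # SysV AMD64: the first six integer/pointer arguments travel in these
    # registers, the first eight float arguments in xmm0..xmm7; all of them
    # are caller-saved, hence the PUSHes and MOVSDs at the top of the slowpath
    ARG_REGS = ['rdi', 'rsi', 'rdx', 'rcx', 'r8', 'r9']
    XMM_ARG_REGS = ['xmm%d' % i for i in range(8)]

The second footer is the subtle part: the extra ADD esp, WORD discards the slow
path's own return address, so the RET returns to the caller's caller, abandoning
the interrupted frame with exit_frame_with_exception_v already loaded into eax.
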
     def assemble_loop(self, inputargs, operations, looptoken, log):
         '''adds the following attributes to looptoken:
                _x86_loop_code       (an integer giving an address)
         for regloc in self.cpu.CALLEE_SAVE_REGISTERS:
             self.mc.PUSH_r(regloc.value)
 
+    def _call_header_with_stack_check(self):
+        startaddr, length, slowpathaddr = self.cpu.insert_stack_check()
+        if slowpathaddr == 0:
+            pass                # no stack check (e.g. not translated)
+        else:
+            self.mc.MOV(eax, esp)                       # MOV eax, current
+            self.mc.SUB(eax, heap(startaddr))           # SUB eax, [startaddr]
+            self.mc.CMP(eax, imm(length))               # CMP eax, length
+            self.mc.J_il8(rx86.Conditions['B'], 0)      # JB .skip
+            jb_location = self.mc.get_relative_pos()
+            self.mc.CALL(self.stack_check_slowpath_imm) # CALL slowpath
+            # patch the JB above                        # .skip:
+            offset = self.mc.get_relative_pos() - jb_location
+            assert 0 < offset <= 127
+            self.mc.overwrite(jb_location-1, chr(offset))
+            #
+        self._call_header()
+
     def _call_footer(self):
         self.mc.LEA_rb(esp.value, -len(self.cpu.CALLEE_SAVE_REGISTERS) * WORD)
 
         # XXX this can be improved greatly. Right now it'll behave like
         #     a normal call
         nonfloatlocs, floatlocs = arglocs
-        self._call_header()
+        self._call_header_with_stack_check()
         self.mc.LEA_rb(esp.value, self._get_offset_of_ebp_from_esp(stackdepth))
         for i in range(len(nonfloatlocs)):
             loc = nonfloatlocs[i]
         unused_xmm = [xmm7, xmm6, xmm5, xmm4, xmm3, xmm2, xmm1, xmm0]
 
         nonfloatlocs, floatlocs = arglocs
-        self._call_header()
+        self._call_header_with_stack_check()
         self.mc.LEA_rb(esp.value, self._get_offset_of_ebp_from_esp(stackdepth))
 
         # The lists are padded with Nones

pypy/jit/backend/x86/regalloc.py

             return StackLoc(i, get_ebp_ofs(i), 1, box_type)
 
 class RegAlloc(object):
-    exc = False
 
     def __init__(self, assembler, translate_support_code=False):
         assert isinstance(translate_support_code, bool)
         locs = [self.loc(op.getarg(i)) for i in range(op.numargs())]
         locs_are_ref = [op.getarg(i).type == REF for i in range(op.numargs())]
         fail_index = self.assembler.cpu.get_fail_descr_number(op.getdescr())
-        self.assembler.generate_failure(fail_index, locs, self.exc,
+        # note: no exception should currently be set in llop.get_exception_addr,
+        # even if this finish corresponds to exit_frame_with_exception (in that
+        # case the exception instance is in locs[0]).
+        self.assembler.generate_failure(fail_index, locs, False,
                                         locs_are_ref)
         self.possibly_free_vars_for_op(op)
 

pypy/jit/backend/x86/rx86.py

     INSN_rr = insn(rex_w, chr(base+1), register(2,8), register(1,1), '\xC0')
     INSN_br = insn(rex_w, chr(base+1), register(2,8), stack_bp(1))
     INSN_rb = insn(rex_w, chr(base+3), register(1,8), stack_bp(2))
+    INSN_rj = insn(rex_w, chr(base+3), register(1,8), '\x05', immediate(2))
     INSN_bi8 = insn(rex_w, '\x83', orbyte(base), stack_bp(1), immediate(2,'b'))
     INSN_bi32= insn(rex_w, '\x81', orbyte(base), stack_bp(1), immediate(2))
 
             INSN_bi32(mc, offset, immed)
     INSN_bi._always_inline_ = True      # try to constant-fold single_byte()
 
-    return INSN_ri, INSN_rr, INSN_rb, INSN_bi, INSN_br
+    return INSN_ri, INSN_rr, INSN_rb, INSN_bi, INSN_br, INSN_rj
 
 def select_8_or_32_bit_immed(insn_8, insn_32):
     def INSN(*args):
 
     # ------------------------------ Arithmetic ------------------------------
 
-    ADD_ri, ADD_rr, ADD_rb, _, _ = common_modes(0)
-    OR_ri,  OR_rr,  OR_rb,  _, _ = common_modes(1)
-    AND_ri, AND_rr, AND_rb, _, _ = common_modes(4)
-    SUB_ri, SUB_rr, SUB_rb, _, _ = common_modes(5)
-    XOR_ri, XOR_rr, XOR_rb, _, _ = common_modes(6)
-    CMP_ri, CMP_rr, CMP_rb, CMP_bi, CMP_br = common_modes(7)
+    ADD_ri, ADD_rr, ADD_rb, _, _, ADD_rj = common_modes(0)
+    OR_ri,  OR_rr,  OR_rb,  _, _, OR_rj  = common_modes(1)
+    AND_ri, AND_rr, AND_rb, _, _, AND_rj = common_modes(4)
+    SUB_ri, SUB_rr, SUB_rb, _, _, SUB_rj = common_modes(5)
+    XOR_ri, XOR_rr, XOR_rb, _, _, XOR_rj = common_modes(6)
+    CMP_ri, CMP_rr, CMP_rb, CMP_bi, CMP_br, CMP_rj = common_modes(7)
 
     CMP_mi8 = insn(rex_w, '\x83', orbyte(7<<3), mem_reg_plus_const(1), immediate(2, 'b'))
     CMP_mi32 = insn(rex_w, '\x81', orbyte(7<<3), mem_reg_plus_const(1), immediate(2))
     CMP_ji8 = insn(rex_w, '\x83', '\x3D', immediate(1), immediate(2, 'b'))
     CMP_ji32 = insn(rex_w, '\x81', '\x3D', immediate(1), immediate(2))
     CMP_ji = select_8_or_32_bit_immed(CMP_ji8, CMP_ji32)
-    CMP_rj = insn(rex_w, '\x3B', register(1, 8), '\x05', immediate(2))
 
     CMP32_mi = insn(rex_nw, '\x81', orbyte(7<<3), mem_reg_plus_const(1), immediate(2))
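
The new _rj mode (register op [absolute address]) generalizes the hand-written
CMP_rj that the hunk above deletes: the same opcode byte, followed by a ModRM
byte of 0x05 (mod=00, r/m=101) that selects a register operand plus a 32-bit
absolute displacement.  As a sanity check, the 32-bit encoding hand-assembled
for illustration:

    # bytes for CMP eax, [0x12345678]: opcode 0x3B (the byte the removed
    # CMP_rj spelled out explicitly), ModRM 0x05 with eax's reg field being
    # 0, then the displacement in little-endian order
    expected = '\x3B\x05\x78\x56\x34\x12'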
 

pypy/jit/backend/x86/test/test_ztranslation.py

         assert res == expected
 
     def test_direct_assembler_call_translates(self):
+        """Test CALL_ASSEMBLER and the recursion limit"""
+        from pypy.rlib.rstackovf import StackOverflow
+
         class Thing(object):
             def __init__(self, val):
                 self.val = val
                 i += 1
             return frame.thing.val
 
-        res = self.meta_interp(main, [0], inline=True,
+        driver2 = JitDriver(greens = [], reds = ['n'])
+
+        def main2(bound):
+            try:
+                while portal2(bound) == -bound+1:
+                    bound *= 2
+            except StackOverflow:
+                pass
+            return bound
+
+        def portal2(n):
+            while True:
+                driver2.jit_merge_point(n=n)
+                n -= 1
+                if n <= 0:
+                    return n
+                n = portal2(n)
+        assert portal2(10) == -9
+
+        def mainall(codeno, bound):
+            return main(codeno) + main2(bound)
+
+        res = self.meta_interp(mainall, [0, 1], inline=True,
                                policy=StopAtXPolicy(change))
-        assert res == main(0)
+        print hex(res)
+        assert res & 255 == main(0)
+        bound = res & ~255
+        assert 1024 <= bound <= 131072
+        assert bound & (bound-1) == 0       # a power of two
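
The decoding above works because main2() returns a power of two of at least
1024 while main(0) fits in one byte, so the two results can share a single
integer without colliding.  The same packing in miniature, with made-up values:

    def decode(res):
        low = res & 255        # main(0)'s contribution, assumed < 256
        bound = res & ~255     # main2()'s bound, a multiple of 256
        assert bound & (bound - 1) == 0    # the classic power-of-two test
        return low, bound

    assert decode(2048 + 42) == (42, 2048)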
 
 
 class TestTranslationRemoveTypePtrX86(CCompiledMixin):

pypy/jit/metainterp/pyjitpl.py

             num = self.cpu.get_fail_descr_number(tokens[0].finishdescr)
             setattr(self.cpu, 'done_with_this_frame_%s_v' % name, num)
         #
+        tokens = self.loop_tokens_exit_frame_with_exception_ref
+        num = self.cpu.get_fail_descr_number(tokens[0].finishdescr)
+        self.cpu.exit_frame_with_exception_v = num
+        #
         self.globaldata = MetaInterpGlobalData(self)
 
     def _setup_once(self):
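
This mirrors the done_with_this_frame_*_v wiring a few lines up: the fail descr
of the exit_frame_with_exception loop token is flattened into a small integer
once, so the backend can embed it in a MOV at assembly time.  In miniature,
with a made-up stand-in for the cpu object:

    class FakeCPU(object):
        exit_frame_with_exception_v = -1   # default from model.py: not set yet
        def get_fail_descr_number(self, descr):
            return 7                       # one stable small integer per descr

    cpu = FakeCPU()
    cpu.exit_frame_with_exception_v = cpu.get_fail_descr_number(object())
    assert cpu.exit_frame_with_exception_v >= 0   # what the assembler asserts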

pypy/rlib/rstack.py

 
 from pypy.rlib.objectmodel import we_are_translated
 from pypy.rlib.rarithmetic import r_uint
+from pypy.rlib import rgc
 from pypy.rpython.extregistry import ExtRegistryEntry
 from pypy.rpython.lltypesystem import lltype, rffi
 from pypy.rpython.lltypesystem.lloperation import llop
         return
     #
     # Else call the slow path
+    stack_check_slowpath(current)
+stack_check._always_inline_ = True
+
+@rgc.no_collect
+def stack_check_slowpath(current):
     if ord(_stack_too_big_slowpath(current)):
-        #
         # Now we are sure that the stack is really too big.  Note that the
         # stack_unwind implementation is different depending on if stackless
         # is enabled. If it is it unwinds the stack, otherwise it simply
         # raises a RuntimeError.
         stack_unwind()
+stack_check_slowpath._dont_inline_ = True
 
 # ____________________________________________________________
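
The reorganization of stack_check() follows a common RPython pattern: keep the
hot path tiny and force-inlined, and move the rare work into a separate
function that is never inlined.  Here the slow half is additionally marked
@rgc.no_collect, since it is also called directly from the generated assembler.
The shape of the pattern, with made-up names and a made-up limit (the
_always_inline_ and _dont_inline_ attributes are hints for the RPython
translator and are inert under plain CPython):

    LIMIT = 100

    def check(x):
        # common case: one compare and no residual call once inlined
        if x < LIMIT:
            return
        # rare case lives out of line, keeping every inlined copy small
        check_slowpath(x)
    check._always_inline_ = True

    def check_slowpath(x):
        raise RuntimeError("limit exceeded")
    check_slowpath._dont_inline_ = True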