Armin Rigo avatar Armin Rigo committed 6522077

in-progress: call stm_transaction_break() from within the jit-produced
code

Comments (0)

Files changed (8)

rpython/jit/backend/llsupport/assembler.py

             self._build_b_slowpath(d, True)
             self._build_b_slowpath(d, False, for_frame=True)
         # only for stm:
-        if hasattr(gc_ll_descr, 'stm_ptr_eq_FUNCPTR'):
+        if gc_ll_descr.stm:
             self._build_ptr_eq_slowpath()
+            self._build_stm_longjmp_callback()
         else:
             self.ptr_eq_slowpath = None
         # only one of those

rpython/jit/backend/llsupport/stmrewrite.py

                 self.always_inevitable = False
                 self.newops.append(op)
                 continue
-            # ----------  jump, finish, other ignored ops  ----------
-            if op.getopnum() in (rop.JUMP,
-                                 rop.FINISH,
+            # ----------  jumps  ----------
+            if op.getopnum() == rop.JUMP:
+                self.newops.append(
+                    ResOperation(rop.STM_TRANSACTION_BREAK, [], None))
+                self.newops.append(op)
+                continue
+            # ----------  finish, other ignored ops  ----------
+            if op.getopnum() in (rop.FINISH,
                                  rop.FORCE_TOKEN,
                                  rop.READ_TIMESTAMP,
                                  rop.MARK_OPAQUE_PTR,

rpython/jit/backend/x86/arch.py

 #        +--------------------+    <== aligned to 16 bytes
 #        |   return address   |
 #        +--------------------+
+#        |   STM resume buf   |    (4 extra words, only with STM)
+#        +--------------------+
 #        |    saved regs      |
 #        +--------------------+
 #        |   scratch          |
     JITFRAME_FIXED_SIZE = 28 # 13 GPR + 15 XMM
 
 assert PASS_ON_MY_FRAME >= 12       # asmgcc needs at least JIT_USE_WORDS + 3
+
+STM_RESUME_BUF = 4

rpython/jit/backend/x86/assembler.py

 from rpython.jit.backend.llsupport.regalloc import (get_scale, valid_addressing_size)
 from rpython.jit.backend.x86.arch import (FRAME_FIXED_SIZE, WORD, IS_X86_64,
                                        JITFRAME_FIXED_SIZE, IS_X86_32,
-                                       PASS_ON_MY_FRAME)
+                                       PASS_ON_MY_FRAME, STM_RESUME_BUF)
 from rpython.jit.backend.x86.regloc import (eax, ecx, edx, ebx, esp, ebp, esi,
     xmm0, xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, r8, r9, r10, r11, edi,
     r12, r13, r14, r15, X86_64_SCRATCH_REG, X86_64_XMM_SCRATCH_REG,
                 rgc.cast_instance_to_gcref(self.cpu.propagate_exception_descr))
         ofs = self.cpu.get_ofs_of_frame_field('jf_descr')
         self.mc.MOV(RawEbpLoc(ofs), imm(propagate_exception_descr))
-        self.mc.MOV_rr(eax.value, ebp.value)
         #
         self._call_footer()
         rawstart = self.mc.materialize(self.cpu.asmmemmgr, [])
 
     def _get_stm_tl(self, adr):
         """Makes 'adr' relative to threadlocal-base if we run in STM. 
-        Before using such a relative address, call 
-        self._stm_tl_segment_prefix_if_necessary."""
+        Before using such a relative address, call _tl_segment_if_stm()."""
         if self.cpu.gc_ll_descr.stm and we_are_translated():
             # only for STM and not during tests
             result = adr - stmtlocal.threadlocal_base()
         in STM and not during testing."""
         if self.cpu.gc_ll_descr.stm and we_are_translated():
             stmtlocal.tl_segment_prefix(mc)
-        
+
     def _build_stack_check_slowpath(self):
         if self.cpu.gc_ll_descr.stm:
             return      # XXX no stack check on STM for now
         else:
             descr.set_b_slowpath(withcards + 2 * withfloats, rawstart)
 
+
+    def _build_stm_longjmp_callback(self):
+        assert self.cpu.gc_ll_descr.stm
+        if not we_are_translated():
+            return    # tests only
+        #
+        # make the stm_longjmp_callback() function, with signature
+        #     void (*longjmp_callback)(void *stm_resume_buffer)
+        mc = codebuf.MachineCodeBlockWrapper()
+        #
+        # 'edi' contains the stm resume buffer, so the new stack
+        # location that we have to enforce is 'edi - FRAME_FIXED_SIZE * WORD'.
+        if IS_X86_32:
+            mc.MOV_rs(edi.value, WORD)      # first argument
+        mc.MOV_rr(esp.value, edi.value)
+        mc.SUB_ri(esp.value, FRAME_FIXED_SIZE * WORD)
+        #
+        # must restore 'ebp' from its saved value in the shadowstack
+        self._reload_frame_if_necessary(mc)
+        #
+        # jump to the place saved in the stm_resume_buffer
+        # (to "HERE" in genop_stm_transaction_break())
+        mc.MOV_rs(eax.value, FRAME_FIXED_SIZE * WORD)
+        mc.PUSH_r(eax.value)
+        mc.JMP_r(eax.value)
+        self.stm_longjmp_callback_addr = mc.materialize(self.cpu.asmmemmgr, [])
+
+
     @rgc.no_release_gil
     def assemble_loop(self, loopname, inputargs, operations, looptoken, log,
                       logger=None):
             frame_depth = max(frame_depth, target_frame_depth)
         return frame_depth
 
+    def _get_whole_frame_size(self):
+        frame_size = FRAME_FIXED_SIZE
+        if self.cpu.gc_ll_descr.stm:
+            frame_size += STM_RESUME_BUF
+        return frame_size
+
     def _call_header(self):
-        self.mc.SUB_ri(esp.value, FRAME_FIXED_SIZE * WORD)
+        self.mc.SUB_ri(esp.value, self._get_whole_frame_size() * WORD)
         self.mc.MOV_sr(PASS_ON_MY_FRAME * WORD, ebp.value)
         if IS_X86_64:
             self.mc.MOV_rr(ebp.value, edi.value)
         else:
-            self.mc.MOV_rs(ebp.value, (FRAME_FIXED_SIZE + 1) * WORD)
+            self.mc.MOV_rs(ebp.value, (self._get_whole_frame_size() + 1) * WORD)
 
         for i, loc in enumerate(self.cpu.CALLEE_SAVE_REGISTERS):
             self.mc.MOV_sr((PASS_ON_MY_FRAME + i + 1) * WORD, loc.value)
             #
 
     def _call_footer(self):
+        if self.cpu.gc_ll_descr.stm and we_are_translated():
+            # call stm_invalidate_jmp_buf(), in case we called
+            # stm_transaction_break() earlier
+            assert IS_X86_64
+            # load the address of the STM_RESUME_BUF
+            self.mc.LEA_rs(edi.value, FRAME_FIXED_SIZE * WORD)
+            fn = stmtlocal.stm_invalidate_jmp_buf_fn
+            self.mc.CALL(imm(self.cpu.cast_ptr_to_int(fn)))
+
+        # the return value is the jitframe
+        self.mc.MOV_rr(eax.value, ebp.value)
+
         gcrootmap = self.cpu.gc_ll_descr.gcrootmap
         if gcrootmap and gcrootmap.is_shadow_stack:
             self._call_footer_shadowstack(gcrootmap)
                            (i + 1 + PASS_ON_MY_FRAME) * WORD)
 
         self.mc.MOV_rs(ebp.value, PASS_ON_MY_FRAME * WORD)
-        self.mc.ADD_ri(esp.value, FRAME_FIXED_SIZE * WORD)
+        self.mc.ADD_ri(esp.value, self._get_whole_frame_size() * WORD)
         self.mc.RET()
 
     def _load_shadowstack_top_in_ebx(self, mc, gcrootmap):
         mc.MOV_br(ofs2, eax.value)
         mc.POP(eax)
         mc.MOV_br(ofs, eax.value)
-        # the return value is the jitframe
-        mc.MOV_rr(eax.value, ebp.value)
 
         self._call_footer()
         rawstart = mc.materialize(self.cpu.asmmemmgr, [])
         assert isinstance(reg, RegLoc)
         self.mc.MOV_rr(reg.value, ebp.value)
 
+    def genop_stm_transaction_break(self, op, arglocs, result_loc):
+        assert self.cpu.gc_ll_descr.stm
+        if not we_are_translated():
+            return     # tests only
+        # "if stm_should_break_transaction()"
+        mc = self.mc
+        fn = stmtlocal.stm_should_break_transaction_fn
+        mc.CALL(imm(self.cpu.cast_ptr_to_int(fn)))
+        mc.TEST8_rr(eax.value, eax.value)
+        mc.J_il8(rx86.Conditions['Z'], 0)
+        jz_location = mc.get_relative_pos()
+        #
+        # call stm_transaction_break() with the address of the
+        # STM_RESUME_BUF and the custom longjmp function
+        mc.LEA_rs(edi.value, FRAME_FIXED_SIZE * WORD)
+        mc.MOV_ri(esi.value, self.stm_longjmp_callback_addr)
+        fn = stmtlocal.stm_transaction_break_fn
+        mc.CALL(imm(self.cpu.cast_ptr_to_int(fn)))
+        #
+        # Fill the stm resume buffer.  Don't do it before the call!
+        # The previous transaction may still be aborted during the call
+        # above, so we need the old content of the buffer!
+        # For now the buffer only contains the address of the resume
+        # point in this piece of code (at "HERE").
+        mc.CALL_l(0)
+        # "HERE"
+        mc.POP_r(eax.value)
+        mc.MOV_sr(FRAME_FIXED_SIZE * WORD, eax.value)
+        #
+        # patch the JZ above
+        offset = mc.get_relative_pos() - jz_location
+        assert 0 < offset <= 127
+        mc.overwrite(jz_location-1, chr(offset))
+
+
 genop_discard_list = [Assembler386.not_implemented_op_discard] * rop._LAST
 genop_list = [Assembler386.not_implemented_op] * rop._LAST
 genop_llong_list = {}

rpython/jit/backend/x86/regalloc.py

      RegisterManager, TempBox, compute_vars_longevity, is_comparison_or_ovf_op)
 from rpython.jit.backend.x86 import rx86
 from rpython.jit.backend.x86.arch import (WORD, JITFRAME_FIXED_SIZE, IS_X86_32,
-    IS_X86_64)
+    IS_X86_64, FRAME_FIXED_SIZE)
 from rpython.jit.backend.x86.jump import remap_frame_layout_mixed
 from rpython.jit.backend.x86.regloc import (FrameLoc, RegLoc, ConstFloatLoc,
     FloatImmedLoc, ImmedLoc, imm, imm0, imm1, ecx, eax, edx, ebx, esi, edi,
                 if isinstance(loc, FrameLoc):
                     self.fm.hint_frame_locations[box] = loc
 
+    def consider_stm_transaction_break(self, op):
+        # XXX use the extra 3 words in the stm resume buffer to save
+        # up to 3 registers, too.  For now we just flush them all.
+        self.xrm.before_call(save_all_regs=1)
+        self.rm.before_call(save_all_regs=1)
+        self.perform(op, [], None)
+
     def consider_jump(self, op):
         assembler = self.assembler
         assert self.jump_target_descr is None

rpython/jit/backend/x86/stmtlocal.py

-from rpython.rtyper.lltypesystem import lltype, rffi
+from rpython.rtyper.lltypesystem import lltype, rffi, llmemory
 from rpython.translator.tool.cbuild import ExternalCompilationInfo
 from rpython.jit.backend.x86.arch import WORD
 
         mc.writechar('\x65')   # %gs:
     else:
         mc.writechar('\x64')   # %fs:
+
+
+# special STM functions called directly by the JIT backend
+stm_should_break_transaction_fn = rffi.llexternal(
+    'stm_should_break_transaction',
+    [], lltype.Bool,
+    sandboxsafe=True, _nowrapper=True, transactionsafe=True)
+stm_transaction_break_fn = rffi.llexternal(
+    'stm_transaction_break',
+    [llmemory.Address, llmemory.Address], lltype.Void,
+    sandboxsafe=True, _nowrapper=True, transactionsafe=True)
+stm_invalidate_jmp_buf_fn = rffi.llexternal(
+    'stm_invalidate_jmp_buf',
+    [llmemory.Address], lltype.Void,
+    sandboxsafe=True, _nowrapper=True, transactionsafe=True)

rpython/jit/metainterp/executor.py

                          rop.CALL_MALLOC_NURSERY_VARSIZE,
                          rop.CALL_MALLOC_NURSERY_VARSIZE_FRAME,
                          rop.LABEL,
+                         rop.STM_TRANSACTION_BREAK,
                          ):      # list of opcodes never executed by pyjitpl
                 continue
             raise AssertionError("missing %r" % (key,))

rpython/jit/metainterp/resoperation.py

     'QUASIIMMUT_FIELD/1d',    # [objptr], descr=SlowMutateDescr
     'RECORD_KNOWN_CLASS/2',   # [objptr, clsptr]
     'KEEPALIVE/1',
+    'STM_TRANSACTION_BREAK/0',
 
     '_CANRAISE_FIRST', # ----- start of can_raise operations -----
     '_CALL_FIRST',
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.