1. Pypy
  2. Untitled project
  3. pypy

Commits

Remi Meier  committed 8f981d8

WIP introduction of explicit jit_stm_transaction_break_point() placement in
pypy to ensure bytecode instruction atomicity. (also fixes always_inevitable
in stmrewrite.py)

  • Participants
  • Parent commits e5e15da
  • Branches stmgc-c4

Comments (0)

Files changed (19)

File TODO

View file
 ------------------------------------------------------------
 
 POSSIBLE BUG:
+part of this is done: still investigate where transaction
+  breaks are really allowed to happen in the JIT. (JUMP,
+  FINISH, call_footer(), ...)
 investigate if another thread can force a jitframe. Thus,
 making a transaction break *after* a guard_not_forced
 would be wrong, as the force will only be visible after

File pypy/interpreter/pyopcode.py

View file
 from pypy.interpreter.pycode import PyCode, BytecodeCorruption
 from rpython.tool.sourcetools import func_with_new_name
 from rpython.rlib.objectmodel import we_are_translated
-from rpython.rlib import jit, rstackovf
+from rpython.rlib import jit, rstackovf, rstm
 from rpython.rlib.rarithmetic import r_uint, intmask
 from rpython.rlib.debug import check_nonneg
 from pypy.tool.stdlib_opcode import bytecode_spec
+from rpython.rlib.jit import we_are_jitted
 
 def unaryoperation(operationname):
     """NOT_RPYTHON"""
                     stmonly_jitdriver.jit_merge_point(
                         self=self, co_code=co_code,
                         next_instr=next_instr, ec=ec)
+                    # nothing inbetween!
+                    rstm.jit_stm_transaction_break_point(False)
                     self = self._hints_for_stm()
                 next_instr = self.handle_bytecode(co_code, next_instr, ec)
         except ExitFrame:
                 # one of the opcodes in the one of the sequences
                 #    * POP_TOP/LOAD_CONST/RETURN_VALUE
                 #    * POP_TOP/LOAD_FAST/RETURN_VALUE
-                from rpython.rlib import rstm
                 if rstm.should_break_transaction():
                     opcode = ord(co_code[next_instr])
                     if opcode not in (opcodedesc.RETURN_VALUE.index,

File pypy/module/pypyjit/interp_jit.py

View file
 from rpython.tool.pairtype import extendabletype
 from rpython.rlib.rarithmetic import r_uint, intmask
 from rpython.rlib.jit import JitDriver, hint, we_are_jitted, dont_look_inside
-from rpython.rlib import jit
+from rpython.rlib import jit, rstm
 from rpython.rlib.jit import current_trace_length, unroll_parameters
 import pypy.interpreter.pyopcode   # for side-effects
 from pypy.interpreter.error import OperationError, operationerrfmt
                 pypyjitdriver.jit_merge_point(ec=ec,
                     frame=self, next_instr=next_instr, pycode=pycode,
                     is_being_profiled=is_being_profiled)
+                # nothing inbetween!
+                rstm.jit_stm_transaction_break_point(False)
                 co_code = pycode.co_code
                 self.valuestackdepth = hint(self.valuestackdepth, promote=True)
                 next_instr = self.handle_bytecode(co_code, next_instr, ec)
             self.last_instr = intmask(jumpto)
             ec.bytecode_trace(self, decr_by)
             jumpto = r_uint(self.last_instr)
+            rstm.jit_stm_transaction_break_point(True)
         #
         pypyjitdriver.can_enter_jit(frame=self, ec=ec, next_instr=jumpto,
                                     pycode=self.getcode(),

File rpython/jit/backend/llsupport/stmrewrite.py

View file
         debug_start("jit-stmrewrite-ops")
         # overridden method from parent class
         #
-        insert_transaction_break = False
         for op in operations:
             opnum = op.getopnum()
             if not we_are_translated():
                          rop.DEBUG_MERGE_POINT):
                 self.newops.append(op)
                 continue
+            # ----------  transaction breaks  ----------
+            if opnum == rop.STM_TRANSACTION_BREAK:
+                self.emitting_an_operation_that_can_collect()
+                self.next_op_may_be_in_new_transaction()
+                self.newops.append(op)
+                continue
             # ----------  ptr_eq  ----------
             if opnum in (rop.PTR_EQ, rop.INSTANCE_PTR_EQ,
                          rop.PTR_NE, rop.INSTANCE_PTR_NE):
             # ----------  pure operations, guards  ----------
             if op.is_always_pure() or op.is_guard() or op.is_ovf():
                 self.newops.append(op)
-
-                # insert a transaction break after call_release_gil
-                # in order to commit the inevitable transaction following
-                # it immediately
-                if (opnum == rop.GUARD_NOT_FORCED
-                    and insert_transaction_break):
-                    # insert transaction_break after GUARD after calls
-                    self.newops.append(
-                        ResOperation(rop.STM_TRANSACTION_BREAK,
-                                     [ConstInt(0)], None))
-                    insert_transaction_break = False
-                    self.emitting_an_operation_that_can_collect()
-                    self.next_op_may_be_in_new_transaction()
-                else:
-                    assert insert_transaction_break is False
-
                 continue
             # ----------  getfields  ----------
             if opnum in (rop.GETFIELD_GC, rop.GETARRAYITEM_GC,
             if op.is_call():
                 self.emitting_an_operation_that_can_collect()
                 self.next_op_may_be_in_new_transaction()
-
-                if opnum in (rop.CALL_MAY_FORCE, rop.CALL_ASSEMBLER,
-                             rop.CALL_RELEASE_GIL):
-                    # insert more transaction breaks after function
-                    # calls since they are likely to return as
-                    # inevitable transactions
-                    insert_transaction_break = True
-                    
+                                    
                 if opnum == rop.CALL_RELEASE_GIL:
                     # self.fallback_inevitable(op)
                     # is done by assembler._release_gil_shadowstack()
                     # non-transactionsafe and non-releasegil function
                     descr = op.getdescr()
                     assert not descr or isinstance(descr, CallDescr)
+                    
                     if not descr or not descr.get_extra_info() \
                       or descr.get_extra_info().call_needs_inevitable():
                         self.fallback_inevitable(op)
                 self.emitting_an_operation_that_can_collect()
                 self.next_op_may_be_in_new_transaction()
                 
-                self.known_lengths.clear()
-                self.always_inevitable = False
                 self.newops.append(op)
                 continue
             # ----------  jumps  ----------
             if opnum == rop.JUMP:
-                self.newops.append(
-                    ResOperation(rop.STM_TRANSACTION_BREAK,
-                                 [ConstInt(1)], None))
-                # self.emitting_an_operation_that_can_collect()
                 self.newops.append(op)
                 continue
             # ----------  finish, other ignored ops  ----------
             debug_print("fallback for", op.repr())
             #
 
-        # call_XX without guard_not_forced?
-        assert not insert_transaction_break
         debug_stop("jit-stmrewrite-ops")
         return self.newops
 
         self.invalidate_write_categories()
     
     def next_op_may_be_in_new_transaction(self):
+        self.known_lengths.clear() # XXX: check if really necessary or
+                                   # just for labels
         self.known_category.clear()
+        self.always_inevitable = False
 
     def invalidate_write_categories(self):
         for v, c in self.known_category.items():

File rpython/jit/backend/x86/assembler.py

View file
             return    # tests only
 
         """ While arriving on slowpath, we have a gcpattern on stack 0.
-        This function must preserve all registers
+        This function does not have to preserve registers. It expects
+        all registers to be saved in the caller.
         """
         mc = codebuf.MachineCodeBlockWrapper()
         # store the gc pattern
         mc.MOV(dest_addr, X86_64_SCRATCH_REG)
 
         
-    def stm_transaction_break(self, check_type, gcmap):
+    def stm_transaction_break(self, gcmap):
         assert self.cpu.gc_ll_descr.stm
         if not we_are_translated():
             return     # tests only
 
-        # check_type: 0 do a check for inevitable before
-        # doing a check of stm_should_break_transaction().
-        # else, just do stm_should_break_transaction()
         mc = self.mc
-        if check_type == 0:
-            # only check stm_should_break_transaction()
-            # if we are inevitable:
-            nc = self._get_stm_tl(rstm.get_active_adr())
-            self._tl_segment_if_stm(mc)
-            mc.CMP_ji(nc, 1)
-            mc.J_il(rx86.Conditions['Z'], 0xfffff)    # patched later
-            jz_location = mc.get_relative_pos()
-        else:
-            jz_location = 0
-        
         # if stm_should_break_transaction()
         fn = stmtlocal.stm_should_break_transaction_fn
         mc.CALL(imm(self.cpu.cast_ptr_to_int(fn)))
         # CALL break function
         fn = self.stm_transaction_break_path
         mc.CALL(imm(fn))
-        # HERE is the place an aborted transaction retries
+        # ** HERE ** is the place an aborted transaction retries
+        # ebp/frame reloaded by longjmp callback
         #
         # restore regs
         base_ofs = self.cpu.get_baseofs_of_frame_field()
             mc.MOVSD_xb(xr.value, (ofs + xr.value * coeff) * WORD + base_ofs)
         #
         # patch the JZ above
-        if jz_location:
-            offset = mc.get_relative_pos() - jz_location
-            mc.overwrite32(jz_location-4, offset)
         offset = mc.get_relative_pos() - jz_location2
         mc.overwrite32(jz_location2-4, offset)
 

File rpython/jit/backend/x86/regalloc.py

View file
         self.perform_discard(op, [base_loc, ofs_loc, size_loc])
         
     def consider_stm_transaction_break(self, op):
-        check_type_box = op.getarg(0)
-        assert isinstance(check_type_box, ConstInt)
-        check_type = check_type_box.getint()
         #
         # only save regs for the should_break_transaction call
         self.xrm.before_call()
         self.rm.before_call()
         gcmap = self.get_gcmap() # allocate the gcmap *before*
         #
-        self.assembler.stm_transaction_break(check_type, gcmap)
+        self.assembler.stm_transaction_break(gcmap)
         
 
     def consider_jump(self, op):

File rpython/jit/codewriter/jtransform.py

View file
                                           [v], None))
         return ops
 
+    def rewrite_op_jit_stm_transaction_break_point(self, op):
+        if isinstance(op.args[0], Constant):
+            arg = int(op.args[0].value)
+            c_arg = Constant(arg, lltype.Signed)
+        else:
+            log.WARNING("stm_transaction_break_point without const argument, assuming False in %r" % (self.graph,))
+            c_arg = Constant(0, lltype.Signed)
+
+        return SpaceOperation('stm_transaction_break', [c_arg], op.result)
+    
     def rewrite_op_jit_marker(self, op):
         key = op.args[0].value
         jitdriver = op.args[1].value

File rpython/jit/codewriter/test/test_jtransform.py

View file
     assert block.operations[1].result is None
     assert block.exits[0].args == [v1]
 
+def test_jit_stm_transaction_break_point():
+    op = SpaceOperation('jit_stm_transaction_break_point',
+                        [Constant(1, lltype.Signed)], lltype.Void)
+    tr = Transformer()
+    op2 = tr.rewrite_operation(op)
+    assert op2.opname == 'stm_transaction_break'
+    assert op2.args[0].value == 1
+    
 def test_jit_merge_point_1():
     class FakeJitDriverSD:
         index = 42

File rpython/jit/metainterp/blackhole.py

View file
     def bhimpl_ref_isvirtual(x):
         return False
 
+    @arguments("i")
+    def bhimpl_stm_transaction_break(if_there_is_no_other):
+        pass
+    
     # ----------
     # the main hints and recursive calls
 

File rpython/jit/metainterp/executor.py

View file
                          rop.CALL_MALLOC_NURSERY_VARSIZE,
                          rop.CALL_MALLOC_NURSERY_VARSIZE_FRAME,
                          rop.LABEL,
-                         rop.STM_TRANSACTION_BREAK,
                          rop.STM_SET_REVISION_GC,
                          ):      # list of opcodes never executed by pyjitpl
                 continue

File rpython/jit/metainterp/pyjitpl.py

View file
 
 class MIFrame(object):
     debug = False
+    # Write resops corresponding to jitcodes
 
     def __init__(self, metainterp):
         self.metainterp = metainterp
         self.parent_resumedata_frame_info_list = None
         # counter for unrolling inlined loops
         self.unroll_iterations = 1
+        # for stm: placement of stm_break_point
+        self.stm_break_wanted = False
+        self.stm_break_done = False
 
     @specialize.arg(3)
     def copy_constants(self, registers, constants, ConstClass):
             raise AssertionError("bad result box type")
 
     # ------------------------------
-
+    @arguments("int")
+    def opimpl_stm_transaction_break(self, if_there_is_no_other):
+        val = bool(if_there_is_no_other)
+        if (self.stm_break_wanted or (val and not self.stm_break_done)):
+            self.stm_break_done = True
+            self.stm_break_wanted = False
+            if not val:
+                print "did an stm_transaction_break(False)"
+            else:
+                print "did an stm_transaction_break(True)"
+            self.execute(rop.STM_TRANSACTION_BREAK, ConstInt(val))
+        elif not val:
+            print "ignored stm_transaction_break(False)"
+        elif val:
+            print "ignored stm_transaction_break(True)"
+    
     for _opimpl in ['int_add', 'int_sub', 'int_mul', 'int_floordiv', 'int_mod',
                     'int_lt', 'int_le', 'int_eq',
                     'int_ne', 'int_gt', 'int_ge',
             # XXX refactor: direct_libffi_call() is a hack
             if effectinfo.oopspecindex == effectinfo.OS_LIBFFI_CALL:
                 self.metainterp.direct_libffi_call()
+            self.stm_break_wanted = True
             return resbox
         else:
             effect = effectinfo.extraeffect

File rpython/jit/tl/tlc.py

View file
 from rpython.jit.tl import tlopcode
 from rpython.rlib.jit import JitDriver, elidable
 from rpython.rlib.rarithmetic import is_valid_int
+from rpython.rlib import rstm
 
 
 class Obj(object):
             if jitted:
                 myjitdriver.jit_merge_point(frame=frame,
                                             code=code, pc=pc, pool=pool)
+                # nothing inbetween!
+                rstm.jit_stm_transaction_break_point(False)
             opcode = ord(code[pc])
             pc += 1
             stack = frame.stack
                 pc += char2int(code[pc])
                 pc += 1
                 if jitted and old_pc > pc:
+                    rstm.jit_stm_transaction_break_point(True)
                     myjitdriver.can_enter_jit(code=code, pc=pc, frame=frame,
                                               pool=pool)
                 
                     old_pc = pc
                     pc += char2int(code[pc]) + 1
                     if jitted and old_pc > pc:
+                        rstm.jit_stm_transaction_break_point(True)
                         myjitdriver.can_enter_jit(code=code, pc=pc, frame=frame,
                                                   pool=pool)
                 else:
                     old_pc = pc
                     pc += offset
                     if jitted and old_pc > pc:
+                        rstm.jit_stm_transaction_break_point(True)
                         myjitdriver.can_enter_jit(code=code, pc=pc, frame=frame,
                                                   pool=pool)
                         

File rpython/rlib/rstm.py

View file
     addr = llop.stm_get_adr_of_read_barrier_cache(llmemory.Address)
     return rffi.cast(lltype.Signed, addr)
 
+def jit_stm_transaction_break_point(if_there_is_no_other):
+    # if_there_is_no_other means that we use this point only
+    # if there is no other break point in the trace.
+    # If it is False, the point may be used if it comes right
+    # a CALL_RELEASE_GIL
+    pass # specialized below
+    # llop.jit_stm_transaction_break_point(lltype.Void,
+    #                                      if_there_is_no_other)
+
+class JitSTMTransactionBreakPoint(ExtRegistryEntry):
+    _about_ = jit_stm_transaction_break_point
+    def compute_result_annotation(self, arg):
+        from rpython.annotator import model as annmodel
+        return annmodel.s_None
+    def specialize_call(self, hop):
+        [v_arg] = hop.inputargs(lltype.Bool)
+        hop.exception_cannot_occur()
+        return hop.genop('jit_stm_transaction_break_point', [v_arg],
+                         resulttype=lltype.Void)
+    
 @dont_look_inside
 def become_inevitable():
     llop.stm_become_inevitable(lltype.Void)

File rpython/rtyper/llinterp.py

View file
     op_stm_become_inevitable = _stm_not_implemented
     op_stm_stop_all_other_threads = _stm_not_implemented
     op_stm_partial_commit_and_resume_other_threads = _stm_not_implemented
+    op_jit_stm_transaction_break_point = _stm_not_implemented
 
     # __________________________________________________________
     # operations on addresses

File rpython/rtyper/lltypesystem/lloperation.py

View file
     # NOTE: use canmallocgc for all operations that can contain a collection.
     #       that includes all that do 'BecomeInevitable' or otherwise contain
     #       possible GC safe-points! (also sync with stmframework.py)
+    # (some ops like stm_commit_transaction don't need it because there
+    #  must be no gc-var access afterwards anyway)
     'stm_initialize':         LLOp(),
     'stm_finalize':           LLOp(),
     'stm_barrier':            LLOp(sideeffects=False),
     'jit_assembler_call': LLOp(canrun=True,   # similar to an 'indirect_call'
                                canraise=(Exception,),
                                canmallocgc=True),
+    'jit_stm_transaction_break_point' : LLOp(),
 
     # __________ GC operations __________
 

File rpython/translator/c/funcgen.py

View file
     OP_STM_MINOR_COLLECT                = _OP_STM
     OP_STM_CLEAR_EXCEPTION_DATA_ON_ABORT= _OP_STM
     OP_STM_ALLOCATE_NONMOVABLE_INT_ADR  = _OP_STM
+    OP_JIT_STM_TRANSACTION_BREAK_POINT  = _OP_STM
 
     def OP_STM_IGNORED_START(self, op):
         return '/* stm_ignored_start */'

File rpython/translator/stm/funcgen.py

View file
                        sizeof(struct pypy_object0 *));
     '''
 
+def jit_stm_transaction_break_point(funcgen, op):
+    return '/* jit_stm_transaction_break_point */'
+    
 def stm_finalize(funcgen, op):
     return 'stm_finalize();'
 

File rpython/translator/stm/inevitable.py

View file
     'stm_threadlocalref_get', 'stm_threadlocalref_set',
     'stm_threadlocalref_count', 'stm_threadlocalref_addr',
     'jit_assembler_call', 'gc_writebarrier',
-    'shrink_array',
+    'shrink_array', 'jit_stm_transaction_break_point',
     ])
 ALWAYS_ALLOW_OPERATIONS |= set(lloperation.enum_tryfold_ops())
 

File rpython/translator/stm/jitdriver.py

View file
                                       cast_base_ptr_to_instance)
 from rpython.rlib import rstm
 from rpython.tool.sourcetools import compile2
+from rpython.translator.c.support import log
 
-
-def find_jit_merge_point(graph):
+def find_jit_merge_point(graph, relaxed=False):
     found = []
     for block in graph.iterblocks():
         for i in range(len(block.operations)):
             op = block.operations[i]
-            if (op.opname == 'jit_marker' and
-                    op.args[0].value == 'jit_merge_point'):
+            if (op.opname == 'jit_marker'
+                and op.args[0].value == 'jit_merge_point'):
                 jitdriver = op.args[1].value
                 if not jitdriver.autoreds:
-                    found.append((block, i))
+                    if (relaxed
+                        or (i + 1 < len(block.operations)
+                            and block.operations[i+1].opname == 'jit_stm_transaction_break_point')):
+                        found.append((block, i))
+                    else:
+                        log.WARNING("ignoring jitdriver without a transaction break point in %r" % (graph,))
                 else:
-                    from rpython.translator.c.support import log
                     log.WARNING("ignoring jitdriver with autoreds in %r" % (
                         graph,))        # XXX XXX!
+
+    assert len(found) <= 1, "several jit_merge_point's in %r" % (graph,)
     if found:
-        assert len(found) == 1, "several jit_merge_point's in %r" % (graph,)
         return found[0]
     else:
         return None
     #     while 1:               ====>      while 1:
     #         jit_merge_point()   |             if should_break_transaction():
     #         stuff_after         |                 return invoke_stm(..)
-    # ----------------------------+             stuff_after
+    #                             |             stuff_after
+    # ----------------------------+             
     #
     # def invoke_stm(..):
     #     p = new container object
     def rewrite_main_graph(self):
         # add 'should_break_transaction()'
         main_graph = self.main_graph
-        block1, i = find_jit_merge_point(main_graph)
+        block1, i = find_jit_merge_point(main_graph, relaxed=True)
         assert i == len(block1.operations) - 1
         del block1.operations[i]
         blockf = self.add_call_should_break_transaction(block1)
         #
         # change the startblock of callback_graph to point just after the
         # jit_merge_point
-        block1, i = find_jit_merge_point(callback_graph)
+        block1, i = find_jit_merge_point(callback_graph, relaxed=True)
         assert i == len(block1.operations) - 1
         del block1.operations[i]
         [link] = block1.exits