Commits

Armin Rigo committed 46ec3e0

In-progress: the goal is to change the GC so that after a collection
the stack variables that were local are still local --- unless they
are specially marked as "stm_local_not_needed" by stm/transform.py.

Comments (0)

Files changed (9)

pypy/interpreter/pyframe.py

 
     # stack manipulation helpers
     def pushvalue(self, w_object):
-        #hint(self, stm_assert_local=True)    XXX re-enable
+        hint(self, stm_assert_local=True)
         depth = self.valuestackdepth
         self.locals_stack_w[depth] = w_object
         self.valuestackdepth = depth + 1
 
     def popvalue(self):
-        #hint(self, stm_assert_local=True)    XXX re-enable
+        hint(self, stm_assert_local=True)
         depth = self.valuestackdepth - 1
         assert depth >= self.pycode.co_nlocals, "pop from empty value stack"
         w_object = self.locals_stack_w[depth]

pypy/interpreter/pyopcode.py

             raise e                   # re-raise the exception we got
 
     def _dispatch_stm_transaction(self, retry_counter):
-        self = self._hints_for_stm()
         try:
             co_code = self.pycode.co_code
             next_instr = r_uint(self.last_instr)
     def dispatch_bytecode(self, co_code, next_instr, ec):
         space = self.space
         while True:
-            self = self._hints_for_stm()
             self.last_instr = intmask(next_instr)
             if not jit.we_are_jitted():
                 ec.bytecode_trace(self)

pypy/rpython/lltypesystem/lloperation.py

     'stm_getinteriorfield':   LLOp(sideeffects=False, canrun=True),
     'stm_become_inevitable':  LLOp(),
     'stm_writebarrier':       LLOp(),
+    'stm_local_not_needed':   LLOp(),
     'stm_normalize_global':   LLOp(),
     'stm_start_transaction':  LLOp(canrun=True, canmallocgc=True),
     'stm_stop_transaction':   LLOp(canrun=True, canmallocgc=True),

pypy/rpython/memory/gctransform/stmframework.py

 
 class StmFrameworkGCTransformer(FrameworkGCTransformer):
 
+    def transform_graph(self, graph):
+        self.vars_local_not_needed = set()
+        super(StmFrameworkGCTransformer, self).transform_graph(graph)
+        self.vars_local_not_needed = None
+
     def _declare_functions(self, GCClass, getfn, s_gc, *args):
         super(StmFrameworkGCTransformer, self)._declare_functions(
             GCClass, getfn, s_gc, *args)
         hop.genop("direct_call", [self.stm_stop_ptr, self.c_const_gc])
         self.pop_roots(hop, livevars)
 
+    def gct_stm_local_not_needed(self, hop):
+        self.vars_local_not_needed.update(hop.spaceop.args)
+
 
 class StmShadowStackRootWalker(BaseRootWalker):
     need_root_stack = True

pypy/translator/stm/gcsource.py

 from pypy.objspace.flow.model import Variable
 from pypy.rpython.lltypesystem import lltype, rclass
 from pypy.translator.simplify import get_graph
-from pypy.translator.unsimplify import split_block
-from pypy.translator.backendopt import graphanalyze
 
 
 COPIES_POINTER = set([
     ])
 
 
-def _is_gc(var_or_const):
+def is_gc(var_or_const):
     TYPE = var_or_const.concretetype
     return isinstance(TYPE, lltype.Ptr) and TYPE.TO._gckind == 'gc'
 
         inputargs = graph.getargs()
         assert len(args) == len(inputargs)
         for v1, v2 in zip(args, inputargs):
-            if _is_gc(v2):
-                assert _is_gc(v1)
+            if is_gc(v2):
+                assert is_gc(v1)
                 resultlist.append((v1, v2))
-        if _is_gc(result):
+        if is_gc(result):
             v = graph.getreturnvar()
-            assert _is_gc(v)
+            assert is_gc(v)
             resultlist.append((v, result))
         was_a_callee.add(graph)
     #
                 if (op.opname in COPIES_POINTER or
                         (op.opname == 'hint' and
                          'stm_write' not in op.args[1].value)):
-                    if _is_gc(op.result) and _is_gc(op.args[0]):
+                    if is_gc(op.result) and is_gc(op.args[0]):
                         resultlist.append((op.args[0], op.result))
                         continue
                 #
                         resultlist.append(('instantiate', op.result))
                         continue
                 #
-                if _is_gc(op.result):
+                if is_gc(op.result):
                     resultlist.append((op, op.result))
             #
             for link in block.exits:
                 for v1, v2 in zip(link.args, link.target.inputargs):
-                    if _is_gc(v2):
-                        assert _is_gc(v1)
+                    if is_gc(v2):
+                        assert is_gc(v1)
                         if v1 is link.last_exc_value:
                             v1 = 'last_exc_value'
                         resultlist.append((v1, v2))
             else:
                 src = 'unknown'
             for v in graph.getargs():
-                if _is_gc(v):
+                if is_gc(v):
                     resultlist.append((src, v))
     return resultlist
 
 
-class TransactionBreakAnalyzer(graphanalyze.BoolGraphAnalyzer):
-    """This analyzer looks for function calls that may ultimately
-    cause a transaction break (end of previous transaction, start
-    of next one)."""
-
-    def analyze_direct_call(self, graph, seen=None):
-        try:
-            func = graph.func
-        except AttributeError:
-            pass
-        else:
-            if getattr(func, '_transaction_break_', False):
-                return True
-        return graphanalyze.GraphAnalyzer.analyze_direct_call(self, graph,
-                                                              seen)
-
-    def analyze_simple_operation(self, op, graphinfo):
-        return op.opname in ('stm_start_transaction',
-                             'stm_stop_transaction')
-
-
-def enum_transactionbroken_vars(translator, transactionbreak_analyzer):
-    if transactionbreak_analyzer is None:
-        return    # for tests only
-    for graph in translator.graphs:
-        for block in graph.iterblocks():
-            if not block.operations:
-                continue
-            for op in block.operations[:-1]:
-                assert not transactionbreak_analyzer.analyze(op)
-            op = block.operations[-1]
-            if not transactionbreak_analyzer.analyze(op):
-                continue
-            # This block ends in a transaction breaking operation.  So
-            # any variable passed from this block to a next one (with
-            # the exception of the variable freshly returned by the
-            # last operation) must be assumed to be potentially global.
-            for link in block.exits:
-                for v1, v2 in zip(link.args, link.target.inputargs):
-                    if v1 is not op.result:
-                        yield v2
-
-def break_blocks_after_transaction_breaker(translator, graph,
-                                           transactionbreak_analyzer):
-    """Split blocks so that they end immediately after any operation
-    that may cause a transaction break."""
-    for block in list(graph.iterblocks()):
-        for i in range(len(block.operations)-2, -1, -1):
-            op = block.operations[i]
-            if transactionbreak_analyzer.analyze(op):
-                split_block(translator.annotator, block, i + 1)
-
-
 class GcSource(object):
     """Works like a dict {gcptr-var: set-of-sources}.  A source is a
     Constant, or a SpaceOperation that creates the value, or a string
     which describes a special case."""
 
-    def __init__(self, translator, transactionbreak_analyzer=None):
+    def __init__(self, translator):
         self.translator = translator
         self._backmapping = {}
         for v1, v2 in enum_gc_dependencies(translator):
             self._backmapping.setdefault(v2, []).append(v1)
-        for v2 in enum_transactionbroken_vars(translator,
-                                              transactionbreak_analyzer):
-            self._backmapping.setdefault(v2, []).append('transactionbreak')
 
     def __getitem__(self, variable):
+        set_of_origins, set_of_variables = self._backpropagate(variable)
+        return set_of_origins
+    
+    def backpropagate(self, variable):
+        set_of_origins, set_of_variables = self._backpropagate(variable)
+        return set_of_variables
+
+    def _backpropagate(self, variable):
         result = set()
         pending = [variable]
         seen = set(pending)
                         pending.append(v1)
                 else:
                     result.add(v1)
-        return result
+        return result, seen

pypy/translator/stm/localtracker.py

     of the stmgc: a pointer is 'local' if it goes to the thread-local memory,
     and 'global' if it points to the shared read-only memory area."""
 
-    def __init__(self, translator, transactionbreak_analyzer=None):
+    def __init__(self, translator):
         self.translator = translator
-        self.gsrc = GcSource(translator, transactionbreak_analyzer)
+        self.gsrc = GcSource(translator)
+        # Set of variables on which we have called try_ensure_local()
+        # and it returned True, or recursively the variables that
+        # these variables depend on.  It is the set of variables
+        # holding a value that we really want to be local.  It does
+        # not contain the variables that happen to be local but whose
+        # locality is not useful any more.
+        self.ensured_local_vars = set()
 
-    def is_local(self, variable):
+    def try_ensure_local(self, *variables):
+        for variable in variables:
+            if not self._could_be_local(variable):
+                return False   # one of the passed-in variables cannot be local
+        #
+        # they could all be locals, so flag them and their dependencies
+        # and return True
+        for variable in variables:
+            if (isinstance(variable, Variable) and
+                    variable not in self.ensured_local_vars):
+                depends_on = self.gsrc.backpropagate(variable)
+                self.ensured_local_vars.update(depends_on)
+        return True
+
+    def _could_be_local(self, variable):
         try:
             srcs = self.gsrc[variable]
         except KeyError:
         return True
 
     def assert_local(self, variable, graph='?'):
-        if self.is_local(variable):
+        if self.try_ensure_local(variable):
             return   # fine
         else:
             raise AssertionError(

pypy/translator/stm/test/test_gcsource.py

 from pypy.translator.translator import TranslationContext
 from pypy.translator.stm.gcsource import GcSource
-from pypy.translator.stm.gcsource import TransactionBreakAnalyzer
-from pypy.translator.stm.gcsource import break_blocks_after_transaction_breaker
 from pypy.objspace.flow.model import SpaceOperation, Constant
 from pypy.rpython.lltypesystem import lltype
 from pypy.rlib.jit import hint
         self.n = n
 
 
-def gcsource(func, sig, transactionbreak=False):
+def gcsource(func, sig):
     t = TranslationContext()
     t.buildannotator().build_types(func, sig)
     t.buildrtyper().specialize()
-    if transactionbreak:
-        transactionbreak_analyzer = TransactionBreakAnalyzer(t)
-        transactionbreak_analyzer.analyze_all()
-        for graph in t.graphs:
-            break_blocks_after_transaction_breaker(
-                t, graph, transactionbreak_analyzer)
-    else:
-        transactionbreak_analyzer = None
-    gsrc = GcSource(t, transactionbreak_analyzer)
+    gsrc = GcSource(t)
     return gsrc
 
 def test_simple():
     s = gsrc[v_result]
     assert len(s) == 1
     assert list(s)[0].opname == 'hint'
-
-def test_transactionbroken():
-    def break_transaction():
-        pass
-    break_transaction._transaction_break_ = True
-    #
-    def main(n):
-        x = X(n)
-        break_transaction()
-        return x
-    gsrc = gcsource(main, [int], transactionbreak=True)
-    v_result = gsrc.translator.graphs[0].getreturnvar()
-    s = gsrc[v_result]
-    assert 'transactionbreak' in s
-    #
-    def main(n):
-        break_transaction()
-        x = X(n)
-        return x
-    gsrc = gcsource(main, [int], transactionbreak=True)
-    v_result = gsrc.translator.graphs[0].getreturnvar()
-    s = gsrc[v_result]
-    assert 'transactionbreak' not in s
-    #
-    def main(n):
-        x = X(n)
-        break_transaction()
-        y = X(n)   # extra operation in the same block
-        return x
-    gsrc = gcsource(main, [int], transactionbreak=True)
-    v_result = gsrc.translator.graphs[0].getreturnvar()
-    s = gsrc[v_result]
-    assert 'transactionbreak' in s
-    #
-    def g(n):
-        break_transaction()
-        return X(n)
-    def main(n):
-        return g(n)
-    gsrc = gcsource(main, [int], transactionbreak=True)
-    v_result = gsrc.translator.graphs[0].getreturnvar()
-    s = gsrc[v_result]
-    assert 'transactionbreak' not in s

pypy/translator/stm/test/test_localtracker.py

     def check(self, expected_names):
         got_local_names = set()
         for name, v in self.translator._seen_locals.items():
-            if self.localtracker.is_local(v):
+            if self.localtracker.try_ensure_local(v):
                 got_local_names.add(name)
                 self.localtracker.assert_local(v, 'foo')
         assert got_local_names == set(expected_names)

pypy/translator/stm/transform.py

 
     def __init__(self, translator=None):
         self.translator = translator
+        self.graph = None
         self.count_get_local     = 0
         self.count_get_nonlocal  = 0
         self.count_get_immutable = 0
     def transform(self):
         assert not hasattr(self.translator, 'stm_transformation_applied')
         self.start_log()
-        t = self.translator
-        transactionbreak_analyzer = gcsource.TransactionBreakAnalyzer(t)
-        transactionbreak_analyzer.analyze_all()
-        #
-        for graph in t.graphs:
-            gcsource.break_blocks_after_transaction_breaker(
-                t, graph, transactionbreak_analyzer)
-        #
-        for graph in t.graphs:
+        for graph in self.translator.graphs:
             pre_insert_stm_writebarrier(graph)
-        #
-        self.localtracker = StmLocalTracker(t, transactionbreak_analyzer)
-        for graph in t.graphs:
+        self.localtracker = StmLocalTracker(self.translator)
+        for graph in self.translator.graphs:
             self.transform_graph(graph)
+        self.make_opnames_cannot_malloc_gc()
+        for graph in self.translator.graphs:
+            self.insert_stm_local_not_needed(graph)
         self.localtracker = None
-        #
         self.translator.stm_transformation_applied = True
         self.print_logs()
 
         self.graph = graph
         for block in graph.iterblocks():
             self.transform_block(block)
-        del self.graph
+        self.graph = None
+
+    # ----------
+
+    def make_opnames_cannot_malloc_gc(self):
+        self.opnames_cannot_malloc_gc = set()
+        for name in lloperation.LL_OPERATIONS:
+            if not getattr(lloperation.llop, name).canmallocgc:
+                self.opnames_cannot_malloc_gc.add(name)
+        self.opnames_cannot_malloc_gc.discard('direct_call')
+        self.opnames_cannot_malloc_gc.discard('indirect_call')
+
+    def insert_stm_local_not_needed(self, graph):
+        # put some 'stm_local_not_needed' operations.  These operations mark
+        # GC pointers that are *not* necessarily locals.  The idea is that
+        # non-marked variables should be considered by the shadowstack code
+        # as "must always be a local", a property enforced during collections.
+        #
+        opnames_cannot_malloc_gc = self.opnames_cannot_malloc_gc
+        ensured_local_vars = self.localtracker.ensured_local_vars
+        #
+        for block in graph.iterblocks():
+            if not block.operations:
+                continue
+            alive = set()
+            newoperationsrev = []
+            for op in reversed(block.operations):
+                newoperationsrev.append(op)
+                alive.discard(op.result)
+                alive.update(op.args)
+                if op.opname not in opnames_cannot_malloc_gc:
+                    vlist = [v for v in alive
+                               if (gcsource.is_gc(v) and
+                                   v not in ensured_local_vars)]
+                    if vlist:
+                        newop = SpaceOperation('stm_local_not_needed', vlist,
+                                               varoftype(lltype.Void))
+                        newoperationsrev.append(newop)
+            block.operations = newoperationsrev[::-1]
 
     # ----------
 
             self.count_get_immutable += 1
             newoperations.append(op)
             return
-        if self.localtracker.is_local(op.args[0]):
+        if self.localtracker.try_ensure_local(op.args[0]):
             self.count_get_local += 1
             newoperations.append(op)
             return
             self.count_set_immutable += 1
             newoperations.append(op)
             return
-        # this is not really a transformation, but just an assertion that
-        # it work on local objects.  This should be ensured by
-        # pre_insert_stm_writebarrier().
-        assert self.localtracker.is_local(op.args[0])
+        # this is not just an assertion that it work on local objects
+        # (which should be ensured by pre_insert_stm_writebarrier()):
+        # it also has the effect of recording in localtracker that we
+        # want this variable to be a local
+        self.localtracker.assert_local(op.args[0], self.graph)
         self.count_set_local += 1
         newoperations.append(op)
 
         self.transform_set(newoperations, op)
 
     def stt_stm_writebarrier(self, newoperations, op):
-        if self.localtracker.is_local(op.args[0]):
+        if self.localtracker.try_ensure_local(op.args[0]):
             op = SpaceOperation('same_as', op.args, op.result)
         else:
-            self.count_write_barrier += 1
+            self.count_write_barrier += 1   # the 'stm_writebarrier' op stays
         newoperations.append(op)
 
     def stt_malloc(self, newoperations, op):
         flags = op.args[1].value
         if flags['flavor'] == 'gc':
-            assert self.localtracker.is_local(op.result)
+            self.localtracker.assert_local(op.result, self.graph)
         else:
             turn_inevitable(newoperations, 'malloc-raw')
         newoperations.append(op)
             self.stt_stm_writebarrier(newoperations, op)
             return
         if 'stm_assert_local' in op.args[1].value:
-            self.localtracker.assert_local(op.args[0],
-                                           getattr(self, 'graph', None))
+            self.localtracker.assert_local(op.args[0], self.graph)
             return
         newoperations.append(op)
 
         if T._gckind == 'raw':
             newoperations.append(op)
             return
-        if (self.localtracker.is_local(op.args[0]) and
-            self.localtracker.is_local(op.args[1])):
+        if self.localtracker.try_ensure_local(op.args[0], op.args[1]):   # both
             newoperations.append(op)
             return
         nargs = []
         for op in block.operations:
             if op.opname in gcsource.COPIES_POINTER:
                 assert len(op.args) == 1
-                if gcsource._is_gc(op.result) and gcsource._is_gc(op.args[0]):
+                if gcsource.is_gc(op.result) and gcsource.is_gc(op.args[0]):
                     copies[op.result] = op
             elif (op.opname in ('getfield', 'getarrayitem',
                                 'getinteriorfield') and