Commits

Armin Rigo committed dfd6ba6

Change the ThreadLocalRef to not be a cache any more. Now the GC should
correctly follow and update the references we store there.

Comments (0)

Files changed (6)

rpython/rlib/rstm.py

 # ____________________________________________________________
 
 class ThreadLocalReference(object):
-    _ALL = weakref.WeakKeyDictionary()
     _COUNT = 0
 
     def __init__(self, Cls):
         "NOT_RPYTHON: must be prebuilt"
         self.Cls = Cls
-        self.local = thread._local()
+        self.local = thread._local()      # <- NOT_RPYTHON
         self.unique_id = ThreadLocalReference._COUNT
         ThreadLocalReference._COUNT += 1
-        ThreadLocalReference._ALL[self] = True
 
     def _freeze_(self):
         return True
     @specialize.arg(0)
     def get(self):
         if we_are_translated():
-            ptr = llop.stm_threadlocalref_get(llmemory.Address, self.unique_id)
-            ptr = rffi.cast(rclass.OBJECTPTR, ptr)
+            ptr = llop.stm_threadlocalref_get(rclass.OBJECTPTR, self.unique_id)
             return cast_base_ptr_to_instance(self.Cls, ptr)
         else:
             return getattr(self.local, 'value', None)
         assert isinstance(value, self.Cls) or value is None
         if we_are_translated():
             ptr = cast_instance_to_base_ptr(value)
-            ptr = rffi.cast(llmemory.Address, ptr)
             llop.stm_threadlocalref_set(lltype.Void, self.unique_id, ptr)
         else:
             self.local.value = value
-
-    @staticmethod
-    def flush_all_in_this_thread():
-        if we_are_translated():
-            # NB. this line is repeated in stmtls.py
-            llop.stm_threadlocalref_flush(lltype.Void)
-        else:
-            for tlref in ThreadLocalReference._ALL.keys():
-                tlref.local.value = None

rpython/rlib/test/test_rstm.py

         x = FooBar()
         results.append(t.get() is None)
         t.set(x)
+        results.append(t.get() is x)
         time.sleep(0.2)
         results.append(t.get() is x)
-        ThreadLocalReference.flush_all_in_this_thread()
-        results.append(t.get() is None)
     for i in range(5):
         thread.start_new_thread(subthread, ())
     time.sleep(0.5)

rpython/rtyper/lltypesystem/lloperation.py

     #'stm_jit_invoke_code':    LLOp(canmallocgc=True),
     'stm_threadlocalref_get': LLOp(sideeffects=False),
     'stm_threadlocalref_set': LLOp(),
-    'stm_threadlocalref_flush': LLOp(),
+    'stm_threadlocalref_count': LLOp(sideeffects=False),
+    'stm_threadlocalref_addr':  LLOp(sideeffects=False),
 
     # __________ address operations __________
 

rpython/rtyper/memory/gc/stmtls.py

         #
         debug_start("gc-local")
         #
-        # First clear all thread-local caches, because they might
-        # contain pointers to objects that are about to move.
-        llop.stm_threadlocalref_flush(lltype.Void)
-        #
         if end_of_transaction:
             self.detect_flag_combination = GCFLAG_LOCAL_COPY | GCFLAG_VISITED
         else:
         # Find the roots that are living in raw structures.
         self.collect_from_raw_structures()
         #
+        # Find the roots in the THREADLOCALREF structure.
+        self.collect_from_threadlocalref()
+        #
         # Also find the roots that are the local copy of global objects.
         self.collect_roots_from_tldict()
         #
         self.gc.root_walker.walk_current_nongc_roots(
             StmGCTLS._trace_drag_out1, self)
 
+    def collect_from_threadlocalref(self):
+        if not we_are_translated():
+            return
+        i = llop.stm_threadlocalref_count(lltype.Signed)
+        while i > 0:
+            i -= 1
+            root = llop.stm_threadlocalref_addr(llmemory.Address, i)
+            self._trace_drag_out(root, None)
+
     def trace_and_drag_out_of_nursery(self, obj):
         # This is called to fix the references inside 'obj', to ensure that
         # they are global.  If necessary, the referenced objects are copied

rpython/translator/stm/test/test_ztranslated.py

 import py
 from rpython.rlib import rstm, rgc
+from rpython.rtyper.lltypesystem import lltype, llmemory
+from rpython.rtyper.lltypesystem.lloperation import llop
+from rpython.rtyper.annlowlevel import cast_instance_to_base_ptr
 from rpython.translator.stm.test.support import NoGcCompiledSTMTests
 from rpython.translator.stm.test.support import CompiledSTMTests
 from rpython.translator.stm.test import targetdemo2
             assert t.get() is None
             t.set(x)
             assert t.get() is x
-            rstm.ThreadLocalReference.flush_all_in_this_thread()
-            assert t.get() is None
+            assert llop.stm_threadlocalref_count(lltype.Signed) == 1
+            p = llop.stm_threadlocalref_addr(llmemory.Address, 0)
+            adr = p.address[0]
+            adr2 = cast_instance_to_base_ptr(x)
+            adr2 = llmemory.cast_ptr_to_adr(adr2)
+            assert adr == adr2
             print "ok"
             return 0
         t, cbuilder = self.compile(main)

rpython/translator/stm/threadlocalref.py

                     ids.add(op.args[0].value)
     #
     ids = sorted(ids)
-    fields = [('ptr%d' % id1, llmemory.Address) for id1 in ids]
-    kwds = {'hints': {'stm_thread_local': True}}
-    S = lltype.Struct('THREADLOCALREF', *fields, **kwds)
+    ARRAY = lltype.FixedSizeArray(llmemory.Address, len(ids))
+    S = lltype.Struct('THREADLOCALREF', ('ptr', ARRAY),
+                      hints={'stm_thread_local': True})
     ll_threadlocalref = lltype.malloc(S, immortal=True)
     c_threadlocalref = Constant(ll_threadlocalref, lltype.Ptr(S))
-    c_fieldnames = {}
-    for id1 in ids:
-        fieldname = 'ptr%d' % id1
-        c_fieldnames[id1] = Constant(fieldname, lltype.Void)
+    c_fieldname = Constant('ptr', lltype.Void)
     c_null = Constant(llmemory.NULL, llmemory.Address)
     #
     for graph in graphs:
             for i in range(len(block.operations)-1, -1, -1):
                 op = block.operations[i]
                 if op.opname == 'stm_threadlocalref_set':
-                    id1 = op.args[0].value
-                    op = SpaceOperation('setfield', [c_threadlocalref,
-                                                     c_fieldnames[id1],
-                                                     op.args[1]],
-                                        op.result)
+                    v_array = varoftype(lltype.Ptr(ARRAY))
+                    ops = [
+                        SpaceOperation('getfield', [c_threadlocalref,
+                                                    c_fieldname],
+                                       v_array),
+                        SpaceOperation('setarrayitem', [v_array,
+                                                        op.args[0],
+                                                        op.args[1]],
+                                       op.result)]
+                    block.operations[i:i+1] = ops
+                elif op.opname == 'stm_threadlocalref_get':
+                    v_array = varoftype(lltype.Ptr(ARRAY))
+                    ops = [
+                        SpaceOperation('getfield', [c_threadlocalref,
+                                                    c_fieldname],
+                                       v_array),
+                        SpaceOperation('getarrayitem', [v_array,
+                                                        op.args[0]],
+                                       op.result)]
+                    block.operations[i:i+1] = ops
+                elif op.opname == 'stm_threadlocalref_addr':
+                    v_array = varoftype(lltype.Ptr(ARRAY))
+                    ops = [
+                        SpaceOperation('getfield', [c_threadlocalref,
+                                                    c_fieldname],
+                                       v_array),
+                        SpaceOperation('direct_ptradd', [v_array,
+                                                         op.args[0]],
+                                       op.result)]
+                    block.operations[i:i+1] = ops
+                elif op.opname == 'stm_threadlocalref_count':
+                    c_count = Constant(len(ids), lltype.Signed)
+                    op = SpaceOperation('same_as', [c_count], op.result)
                     block.operations[i] = op
-                elif op.opname == 'stm_threadlocalref_get':
-                    id1 = op.args[0].value
-                    op = SpaceOperation('getfield', [c_threadlocalref,
-                                                     c_fieldnames[id1]],
-                                        op.result)
-                    block.operations[i] = op
-                elif op.opname == 'stm_threadlocalref_flush':
-                    extra = []
-                    for id1 in ids:
-                        op = SpaceOperation('setfield', [c_threadlocalref,
-                                                         c_fieldnames[id1],
-                                                         c_null],
-                                            varoftype(lltype.Void))
-                        extra.append(op)
-                    block.operations[i:i+1] = extra