Commits

Armin Rigo committed d33ea92

Support undoing changes done to the thread-local structure in case of
aborts.

  • Participants
  • Parent commits b441cc9
  • Branches stm-thread-2

Comments (0)

Files changed (7)

File rpython/rtyper/lltypesystem/lloperation.py

     #'stm_jit_invoke_code':    LLOp(canmallocgc=True),
     'stm_threadlocalref_get': LLOp(sideeffects=False),
     'stm_threadlocalref_set': LLOp(),
-    'stm_threadlocalref_count': LLOp(sideeffects=False),
-    'stm_threadlocalref_addr':  LLOp(sideeffects=False),
+    'stm_threadlocalref_llset': LLOp(),
+    'stm_threadlocalref_llcount': LLOp(sideeffects=False),
+    'stm_threadlocalref_lladdr':  LLOp(sideeffects=False),
 
     # __________ address operations __________
 

File rpython/rtyper/memory/gc/stmtls.py

     def collect_from_threadlocalref(self):
         if not we_are_translated():
             return
-        i = llop.stm_threadlocalref_count(lltype.Signed)
+        i = llop.stm_threadlocalref_llcount(lltype.Signed)
         while i > 0:
             i -= 1
-            root = llop.stm_threadlocalref_addr(llmemory.Address, i)
+            root = llop.stm_threadlocalref_lladdr(llmemory.Address, i)
             self._trace_drag_out(root, None)
 
     def trace_and_drag_out_of_nursery(self, obj):

File rpython/translator/stm/inevitable.py

     'jit_record_known_class',
     'gc_identityhash', 'gc_id',
     'gc_adr_of_root_stack_top',
-
     'weakref_create', 'weakref_deref',
+    'stm_threadlocalref_get', 'stm_threadlocalref_set',
+    'stm_threadlocalref_count', 'stm_threadlocalref_addr',
     ])
 ALWAYS_ALLOW_OPERATIONS |= set(lloperation.enum_tryfold_ops())
 

File rpython/translator/stm/src_stm/et.c

   struct GcPtrList list_of_read_objects;
   struct GcPtrList gcroots;
   struct G2L global_to_local;
+  struct GcPtrList undolog;
   struct FXCache recent_reads_cache;
 };
 
 
   CancelLocks(d);
 
+  if (d->undolog.size > 0) {
+      gcptr *item = d->undolog.items;
+      long i;
+      for (i=d->undolog.size; i>=0; i-=2) {
+          void **addr = (void **)(item[i-2]);
+          void *oldvalue = (void *)(item[i-1]);
+          *addr = oldvalue;
+      }
+  }
+
   /* upon abort, set the reads size limit to 94% of how much was read
      so far.  This should ensure that, assuming the retry does the same
      thing, it will commit just before it reaches the conflicting point. */
   assert(!g2l_any_entry(&d->global_to_local));
   d->count_reads = 0;
   fxcache_clear(&d->recent_reads_cache);
+  gcptrlist_clear(&d->undolog);
 }
 
 void BeginTransaction(jmp_buf* buf)
 
 /************************************************************/
 
+void stm_ThreadLocalRef_LLSet(void **addr, void *newvalue)
+{
+  struct tx_descriptor *d = thread_descriptor;
+  gcptrlist_insert2(&d->undolog, (gcptr)addr, (gcptr)*addr);
+  *addr = newvalue;
+}
+
+/************************************************************/
+
 int DescriptorInit(void)
 {
   if (thread_descriptor == NULL)

File rpython/translator/stm/src_stm/et.h

 #define STM_PTR_EQ(P1, P2)                      \
     stm_PtrEq((gcptr)(P1), (gcptr)(P2))
 
+#define OP_STM_THREADLOCALREF_LLSET(P, X, IGNORED)          \
+    stm_ThreadLocalRef_LLSet((void **)(P), (void *)(X))
+
 /* special usage only */
 #define OP_STM_READ_BARRIER(P, R)   R = STM_BARRIER_P2R(P)
 #define OP_STM_WRITE_BARRIER(P, W)   W = STM_BARRIER_P2W(P)
 void *stm_WriteBarrierFromReady(void *);
 //gcptr _NonTransactionalReadBarrier(gcptr);
 
+void stm_ThreadLocalRef_LLSet(void **P, void *X);
+
 
 extern void *pypy_g__stm_duplicate(void *);
 extern void pypy_g__stm_enum_callback(void *, void *);

File rpython/translator/stm/test/test_ztranslated.py

             assert t.get() is None
             t.set(x)
             assert t.get() is x
-            assert llop.stm_threadlocalref_count(lltype.Signed) == 1
-            p = llop.stm_threadlocalref_addr(llmemory.Address, 0)
+            assert llop.stm_threadlocalref_llcount(lltype.Signed) == 1
+            p = llop.stm_threadlocalref_lladdr(llmemory.Address, 0)
             adr = p.address[0]
             adr2 = cast_instance_to_base_ptr(x)
             adr2 = llmemory.cast_ptr_to_adr(adr2)

File rpython/translator/stm/threadlocalref.py

     c_fieldname = Constant('ptr', lltype.Void)
     c_null = Constant(llmemory.NULL, llmemory.Address)
     #
+    def getaddr(v_num, v_result):
+        v_array = varoftype(lltype.Ptr(ARRAY))
+        ops = [
+            SpaceOperation('getfield', [c_threadlocalref, c_fieldname],
+                           v_array),
+            SpaceOperation('direct_ptradd', [v_array, v_num], v_result)]
+        return ops
+    #
     for graph in graphs:
         for block in graph.iterblocks():
             for i in range(len(block.operations)-1, -1, -1):
                 op = block.operations[i]
                 if op.opname == 'stm_threadlocalref_set':
-                    v_array = varoftype(lltype.Ptr(ARRAY))
-                    ops = [
-                        SpaceOperation('getfield', [c_threadlocalref,
-                                                    c_fieldname],
-                                       v_array),
-                        SpaceOperation('setarrayitem', [v_array,
-                                                        op.args[0],
-                                                        op.args[1]],
-                                       op.result)]
+                    v_addr = varoftype(lltype.Ptr(ARRAY))
+                    ops = getaddr(op.args[0], v_addr)
+                    ops.append(SpaceOperation('stm_threadlocalref_llset',
+                                              [v_addr, op.args[1]],
+                                              op.result))
                     block.operations[i:i+1] = ops
                 elif op.opname == 'stm_threadlocalref_get':
                     v_array = varoftype(lltype.Ptr(ARRAY))
                                                         op.args[0]],
                                        op.result)]
                     block.operations[i:i+1] = ops
-                elif op.opname == 'stm_threadlocalref_addr':
-                    v_array = varoftype(lltype.Ptr(ARRAY))
-                    ops = [
-                        SpaceOperation('getfield', [c_threadlocalref,
-                                                    c_fieldname],
-                                       v_array),
-                        SpaceOperation('direct_ptradd', [v_array,
-                                                         op.args[0]],
-                                       op.result)]
-                    block.operations[i:i+1] = ops
-                elif op.opname == 'stm_threadlocalref_count':
+                elif op.opname == 'stm_threadlocalref_lladdr':
+                    block.operations[i:i+1] = getaddr(op.args[0], op.result)
+                elif op.opname == 'stm_threadlocalref_llcount':
                     c_count = Constant(len(ids), lltype.Signed)
                     op = SpaceOperation('same_as', [c_count], op.result)
                     block.operations[i] = op