Commits

Armin Rigo committed 838c022

Fixes and completion of R_Container field updates. The test fails so far.

  • Participants
  • Parent commits a59519f

Comments (0)

Files changed (2)

File hack/stm/python/c2.py

 
 def local_object(cpu, **kwds):
     o = Object()
+    cpu.store((o, 'h_tid'), 0)
     cpu.store((o, 'h_revision'), REV_INITIAL)
     for key, value in kwds.items():
         cpu.store((o, key), value)
     # Read/write barriers
     # ---------------------------------------
 
-    def LatestGlobalRevision(self, G):
+    def LatestGlobalRevision(self, G, R_Container, FieldName):
         R = G
         while True:
             while True:
                 self.ValidateDuringTransaction()#try to move start_time forward
                 continue                       # restart searching from R
             break
-        self.PossiblyUpdateChain(G, R)         # see below
+        self.PossiblyUpdateChain(G, R, R_Container, FieldName)    # see below
         return R
 
-    def DirectReadBarrier(self, P):
+    def DirectReadBarrier(self, P, R_Container=None, FieldName=None):
         if not self.h_global(P):         # fast-path
             return P
         if not self.h_possibly_outdated(P):
             R = P
         else:
-            R = self.LatestGlobalRevision(P)
+            R = self.LatestGlobalRevision(P, R_Container, FieldName)
             if self.h_possibly_outdated(R) and R in self.global_to_local:
-                L = self.ReadGlobalToLocal(R)  # see below
+                L = self.ReadGlobalToLocal(R, R_Container, FieldName)#see below
                 return L
         R = self.AddInReadSet(R)                    # see below
         return R
 
-    def RepeatReadBarrier(self, O):
+    def RepeatReadBarrier(self, O, R_Container=None, FieldName=None):
         if not self.h_possibly_outdated(O):       # fast-path
             return O
         # LatestGlobalRevision(O) would either return O or abort
         # the whole transaction, so omitting it is not wrong
         if O in self.global_to_local:
-            L = self.ReadGlobalToLocal(O)         # see below
+            L = self.ReadGlobalToLocal(O, R_Container, FieldName) # see below
             return L
         R = O
         return R
         for attr in R._fields:
             x = self.load((R, attr))
             self.store((L, attr), x)
+            L._fields.add(attr)
         print 'cpu %d: localize done' % (self._cpuindex,)
         self.global_to_local[R] = L
         return L
 
-    def WriteBarrier(self, P):
+    def WriteBarrier(self, P, R_Container=None, FieldName=None):
         if self.h_written(P):          # fast-path
             return P
         if not self.h_global(P):
             assert isinstance(R, Object)
         else:
             if self.h_possibly_outdated(P):
-                R = self.LatestGlobalRevision(P)
+                R = self.LatestGlobalRevision(P, R_Container, FieldName)
             else:
                 R = P
             W = self.Localize(R)
             return L
         return R
 
-    def ReadGlobalToLocal(self, R):
+    def ReadGlobalToLocal(self, R, R_Container=None, FieldName=None):
         L = self.global_to_local[R]
+        if R_Container is not None and not self.h_global(R_Container):
+            # fix the original field in-place, if R_Container is local
+            L_Container = R_Container
+            assert FieldName in L_Container._fields
+            self.store((L_Container, FieldName), L)
         return L
 
-    def PossiblyUpdateChain(self, G, R):
-        pass
+    def PossiblyUpdateChain(self, G, R, R_Container, FieldName):
+        if G is not R and random.random() < 0.1:
+            # compress the chain
+            while True:
+                G_next = self.h_revision(G)
+                if G_next is R:
+                    break
+                self.set_h_revision(G, R)
+                G = G_next
+            # update the original field
+            if R_Container is not None:
+                assert FieldName in R_Container._fields
+                self.store((R_Container, FieldName), R)
 
     # Validation
     # ------------------------------------
             for name in L._fields:
                 value = self.load((L, name))
                 if isinstance(value, Object) and value not in seen:
-                    value.add(seen)
+                    seen.add(value)
                     pending.append(value)
 
     # Committing
         print "cpu %d ABORT: %s" % (self._cpuindex, reason)
         assert not self.is_inevitable
         self.CancelLocks()
+        print "cpu %d ABORTED" % (self._cpuindex,)
         raise AbortAndRetry
 
     def UpdateChainHeads(self, cur_time):

File hack/stm/python/test_c2.py

-import py
+import py, os
 import random
 from c2 import CPU_C2, System, prebuilt_object, local_object
 
 
 def test_demo3_readonly_inevitable():
     test_demo3(INEVITABLE_MASK=7, READONLY_FRAC=2)
+
+
+def test_demo5(CPUClass=CPU_C2, LOOP=10, NUM_CPUS=1):
+    start = prebuilt_object(x=None, y=None)
+
+    def action(cpu, p):
+        o = local_object(cpu, x=None, y=None)
+        res = []
+        for i in range(3):
+            ptr = start
+            field = 'xy'[p < 0.5]
+            while p >= 0.1:
+                print p
+                assert p < 1.0
+                p *= 2.0
+                ptr = cpu.DirectReadBarrier(ptr)
+                if p >= 1.0:
+                    p -= 1.0
+                    field = 'x'
+                    res.append('r')
+                else:
+                    field = 'y'
+                    res.append('l')
+                next = cpu.read_field(ptr, field)
+                if next is None:
+                    break
+            print p
+            res.append(field.upper())
+            p = 0.99 - p
+            ptr = cpu.WriteBarrierFromReadReady(ptr)
+            cpu.write_field(ptr, field, o)
+            o = ptr
+        print ''.join(res)
+        return ''.join(res)
+
+    def get_cpu(cpunum):
+        def code(cpu):
+            for i in range(LOOP):
+                #print cpu, 'start transaction', i
+                os.write(1, 'cpu %s: start transaction %d\n' %
+                         (cpu._cpuindex, i))
+                r = random.random()
+                res = cpu.execute_transaction(action, cpu, r)
+                assert cpu.last_commit_time not in transaction_log
+                transaction_log[cpu.last_commit_time] = (r, res)
+        code.__name__ = 'code%d' % cpunum
+        return CPUClass(code)
+
+    def code_alone(cpu):
+        items = transaction_log.items()
+        items.sort()
+        for _, (r, res) in items:
+            res2 = cpu.execute_transaction(action, cpu, r)
+            assert res2 == res
+
+    system1 = System([CPUClass(code_alone)])
+    system2 = System([get_cpu(i) for i in range(NUM_CPUS)])
+    for i in range(100):
+        #
+        transaction_log = {}
+        print
+        print '********** System2 **********'
+        system2.run()
+        #
+        print
+        print '********** System1 **********'
+        system1.run()