Commits

Armin Rigo  committed 5ccd3d3

Thread-local data in the transaction module. See docstring for
two foreseen usage patterns.

  • Participants
  • Parent commits 5caa1ed
  • Branches stm-gc

Comments (0)

Files changed (4)

File pypy/module/transaction/__init__.py

         'run': 'interp_transaction.run',
         'add_epoll': 'interp_epoll.add_epoll',        # xxx linux only
         'remove_epoll': 'interp_epoll.remove_epoll',  # xxx linux only
+        'local': 'interp_local.W_Local',
     }
 
     appleveldefs = {

File pypy/module/transaction/interp_local.py

+from pypy.interpreter.baseobjspace import Wrappable
+from pypy.interpreter.typedef import (TypeDef, interp2app, GetSetProperty,
+    descr_get_dict)
+from pypy.module.transaction.interp_transaction import state
+
+
+class W_Local(Wrappable):
+    """Thread-local data.  Behaves like a regular object, but its content
+    is not shared between multiple concurrently-running transactions.
+    It can be accessed without conflicts.
+
+    It can be used for purely transaction-local data.
+
+    It can also be used for long-living caches that store values that
+    are (1) not too costly to compute and (2) not too memory-hungry,
+    because they will end up being computed and stored once per actual
+    thread.
+    """
+
+    def __init__(self, space):
+        self.dicts = []
+        self._update_dicts(space)
+        # unless we call transaction.set_num_threads() afterwards, this
+        # 'local' object is now initialized with the correct number of
+        # dictionaries, to avoid conflicts later if _update_dicts() is
+        # called in a transaction.
+
+    def _update_dicts(self, space):
+        new = state.get_number_of_threads() - len(self.dicts)
+        if new <= 0:
+            return
+        # update the list without appending to it (to keep it non-resizable)
+        self.dicts = self.dicts + [space.newdict(instance=True)
+                                   for i in range(new)]
+
+    def getdict(self, space):
+        n = state.get_thread_number()
+        try:
+            return self.dicts[n]
+        except IndexError:
+            self._update_dicts(space)
+            assert n < len(self.dicts)
+            return self.dicts[n]
+
+def descr_local__new__(space, w_subtype):
+    local = W_Local(space)
+    return space.wrap(local)
+
+W_Local.typedef = TypeDef("transaction.local",
+            __new__ = interp2app(descr_local__new__),
+            __dict__ = GetSetProperty(descr_get_dict, cls=W_Local),
+            )
+W_Local.typedef.acceptable_as_base_class = False

File pypy/module/transaction/interp_transaction.py

         self.ll_no_tasks_pending_lock = threadintf.null_ll_lock
         self.ll_unfinished_lock = threadintf.null_ll_lock
         self.threadobjs = {}      # empty during translation
+        self.threadnums = {}      # empty during translation
         self.epolls = None
         self.pending = Fifo()
 
     def _freeze_(self):
         self.threadobjs.clear()
+        self.threadnums.clear()
         return False
 
     def startup(self, space, w_module):
         assert id not in self.threadobjs
         ec._transaction_pending = Fifo()
         self.threadobjs[id] = ec
+        self.threadnums[id] = len(self.threadnums)
 
     # ---------- interface for ThreadLocals ----------
     # This works really like a thread-local, which may have slightly
         id = rstm.thread_id()
         assert id == MAIN_THREAD_ID   # should not be used from a transaction
         self.threadobjs[id] = value
+        self.threadnums = {id: 0}
 
     def getmainthreadvalue(self):
         return self.threadobjs.get(MAIN_THREAD_ID, None)
         for id in self.threadobjs.keys():
             if id != MAIN_THREAD_ID:
                 del self.threadobjs[id]
+        self.threadnums = {MAIN_THREAD_ID: 0}
+
+    def get_thread_number(self):
+        id = rstm.thread_id()
+        return self.threadnums[id]
+
+    def get_number_of_threads(self):
+        return 1 + self.num_threads
 
     # ----------
 

File pypy/module/transaction/test/test_local.py

+import py
+from pypy.conftest import gettestobjspace
+
+
+class AppTestLocal:
+    def setup_class(cls):
+        cls.space = gettestobjspace(usemodules=['transaction'])
+
+    def test_simple(self):
+        import transaction
+        x = transaction.local()
+        x.foo = 42
+        assert x.foo == 42
+        assert hasattr(x, 'foo')
+        assert not hasattr(x, 'bar')
+        assert getattr(x, 'foo', 84) == 42
+        assert getattr(x, 'bar', 84) == 84
+
+    def test_transaction_local(self):
+        import transaction
+        transaction.set_num_threads(2)
+        x = transaction.local()
+        all_lists = []
+
+        def f(n):
+            if not hasattr(x, 'lst'):
+                x.lst = []
+                all_lists.append(x.lst)
+            x.lst.append(n)
+            if n > 0:
+                transaction.add(f, n - 1)
+                transaction.add(f, n - 1)
+        transaction.add(f, 5)
+        transaction.run()
+
+        assert not hasattr(x, 'lst')
+        assert len(all_lists) == 2
+        total = all_lists[0] + all_lists[1]
+        assert total.count(5) == 1
+        assert total.count(4) == 2
+        assert total.count(3) == 4
+        assert total.count(2) == 8
+        assert total.count(1) == 16
+        assert total.count(0) == 32
+        assert len(total) == 63
+
+    def test_transaction_local_growing(self):
+        import transaction
+        transaction.set_num_threads(1)
+        x = transaction.local()
+        all_lists = []
+
+        def f(n):
+            if not hasattr(x, 'lst'):
+                x.lst = []
+                all_lists.append(x.lst)
+            x.lst.append(n)
+            if n > 0:
+                transaction.add(f, n - 1)
+                transaction.add(f, n - 1)
+        transaction.add(f, 5)
+
+        transaction.set_num_threads(2)    # more than 1 specified above
+        transaction.run()
+
+        assert not hasattr(x, 'lst')
+        assert len(all_lists) == 2
+        total = all_lists[0] + all_lists[1]
+        assert total.count(5) == 1
+        assert total.count(4) == 2
+        assert total.count(3) == 4
+        assert total.count(2) == 8
+        assert total.count(1) == 16
+        assert total.count(0) == 32
+        assert len(total) == 63