Commits

Amaury Forgeot d'Arc committed fc46fb9

cpyext: Steal from CPython the implementation of TLS functions:
PyThread_set_key_value and friends.

Comments (0)

Files changed (4)

pypy/module/cpyext/api.py

     'PyObject_AsReadBuffer', 'PyObject_AsWriteBuffer', 'PyObject_CheckReadBuffer',
 
     'PyOS_getsig', 'PyOS_setsig',
+    'PyThread_create_key', 'PyThread_delete_key', 'PyThread_set_key_value',
+    'PyThread_get_key_value', 'PyThread_delete_key_value',
+    'PyThread_ReInitTLS',
 
     'PyStructSequence_InitType', 'PyStructSequence_New',
 ]
         lambda space: init_pycobject(),
         lambda space: init_capsule(),
     ])
+    from pypy.module.posix.interp_posix import add_fork_hook
+    reinit_tls = rffi.llexternal('PyThread_ReInitTLS', [], lltype.Void,
+                                 compilation_info=eci)    
+    add_fork_hook('child', reinit_tls)
 
 def init_function(func):
     INIT_FUNCTIONS.append(func)
                                source_dir / "structseq.c",
                                source_dir / "capsule.c",
                                source_dir / "pysignals.c",
+                               source_dir / "thread.c",
                                ],
         separate_module_sources=separate_module_sources,
         export_symbols=export_symbols_eci,

pypy/module/cpyext/include/pythread.h

 
 #define WITH_THREAD
 
+#ifdef __cplusplus
+extern "C" {
+#endif
+
 typedef void *PyThread_type_lock;
 #define WAIT_LOCK	1
 #define NOWAIT_LOCK	0
 
+/* Thread Local Storage (TLS) API */
+PyAPI_FUNC(int) PyThread_create_key(void);
+PyAPI_FUNC(void) PyThread_delete_key(int);
+PyAPI_FUNC(int) PyThread_set_key_value(int, void *);
+PyAPI_FUNC(void *) PyThread_get_key_value(int);
+PyAPI_FUNC(void) PyThread_delete_key_value(int key);
+
+/* Cleanup after a fork */
+PyAPI_FUNC(void) PyThread_ReInitTLS(void);
+
+#ifdef __cplusplus
+}
 #endif
+
+#endif

pypy/module/cpyext/src/thread.c

+#include <Python.h>
+#include "pythread.h"
+
+/* ------------------------------------------------------------------------
+Per-thread data ("key") support.
+
+Use PyThread_create_key() to create a new key.  This is typically shared
+across threads.
+
+Use PyThread_set_key_value(thekey, value) to associate void* value with
+thekey in the current thread.  Each thread has a distinct mapping of thekey
+to a void* value.  Caution:  if the current thread already has a mapping
+for thekey, value is ignored.
+
+Use PyThread_get_key_value(thekey) to retrieve the void* value associated
+with thekey in the current thread.  This returns NULL if no value is
+associated with thekey in the current thread.
+
+Use PyThread_delete_key_value(thekey) to forget the current thread's associated
+value for thekey.  PyThread_delete_key(thekey) forgets the values associated
+with thekey across *all* threads.
+
+While some of these functions have error-return values, none set any
+Python exception.
+
+None of the functions does memory management on behalf of the void* values.
+You need to allocate and deallocate them yourself.  If the void* values
+happen to be PyObject*, these functions don't do refcount operations on
+them either.
+
+The GIL does not need to be held when calling these functions; they supply
+their own locking.  This isn't true of PyThread_create_key(), though (see
+next paragraph).
+
+There's a hidden assumption that PyThread_create_key() will be called before
+any of the other functions are called.  There's also a hidden assumption
+that calls to PyThread_create_key() are serialized externally.
+------------------------------------------------------------------------ */
+
+#ifdef MS_WINDOWS
+#include <windows.h>
+
+/* use native Windows TLS functions */
+#define Py_HAVE_NATIVE_TLS
+
+int
+PyThread_create_key(void)
+{
+    return (int) TlsAlloc();
+}
+
+void
+PyThread_delete_key(int key)
+{
+    TlsFree(key);
+}
+
+/* We must be careful to emulate the strange semantics implemented in thread.c,
+ * where the value is only set if it hasn't been set before.
+ */
+int
+PyThread_set_key_value(int key, void *value)
+{
+    BOOL ok;
+    void *oldvalue;
+
+    assert(value != NULL);
+    oldvalue = TlsGetValue(key);
+    if (oldvalue != NULL)
+        /* ignore value if already set */
+        return 0;
+    ok = TlsSetValue(key, value);
+    if (!ok)
+        return -1;
+    return 0;
+}
+
+void *
+PyThread_get_key_value(int key)
+{
+    /* because TLS is used in the Py_END_ALLOW_THREAD macro,
+     * it is necessary to preserve the windows error state, because
+     * it is assumed to be preserved across the call to the macro.
+     * Ideally, the macro should be fixed, but it is simpler to
+     * do it here.
+     */
+    DWORD error = GetLastError();
+    void *result = TlsGetValue(key);
+    SetLastError(error);
+    return result;
+}
+
+void
+PyThread_delete_key_value(int key)
+{
+    /* NULL is used as "key missing", and it is also the default
+     * given by TlsGetValue() if nothing has been set yet.
+     */
+    TlsSetValue(key, NULL);
+}
+
+/* reinitialization of TLS is not necessary after fork when using
+ * the native TLS functions.  And forking isn't supported on Windows either.
+ */
+void
+PyThread_ReInitTLS(void)
+{}
+
+#else  /* MS_WINDOWS */
+
+/* A singly-linked list of struct key objects remembers all the key->value
+ * associations.  File static keyhead heads the list.  keymutex is used
+ * to enforce exclusion internally.
+ */
+struct key {
+    /* Next record in the list, or NULL if this is the last record. */
+    struct key *next;
+
+    /* The thread id, according to PyThread_get_thread_ident(). */
+    long id;
+
+    /* The key and its associated value. */
+    int key;
+    void *value;
+};
+
+static struct key *keyhead = NULL;
+static PyThread_type_lock keymutex = NULL;
+static int nkeys = 0;  /* PyThread_create_key() hands out nkeys+1 next */
+
+/* Internal helper.
+ * If the current thread has a mapping for key, the appropriate struct key*
+ * is returned.  NB:  value is ignored in this case!
+ * If there is no mapping for key in the current thread, then:
+ *     If value is NULL, NULL is returned.
+ *     Else a mapping of key to value is created for the current thread,
+ *     and a pointer to a new struct key* is returned; except that if
+ *     malloc() can't find room for a new struct key*, NULL is returned.
+ * So when value==NULL, this acts like a pure lookup routine, and when
+ * value!=NULL, this acts like dict.setdefault(), returning an existing
+ * mapping if one exists, else creating a new mapping.
+ *
+ * Caution:  this used to be too clever, trying to hold keymutex only
+ * around the "p->next = keyhead; keyhead = p" pair.  That allowed
+ * another thread to mutate the list, via key deletion, concurrent with
+ * find_key() crawling over the list.  Hilarity ensued.  For example, when
+ * the for-loop here does "p = p->next", p could end up pointing at a
+ * record that PyThread_delete_key_value() was concurrently free()'ing.
+ * That could lead to anything, from failing to find a key that exists, to
+ * segfaults.  Now we lock the whole routine.
+ */
+static struct key *
+find_key(int key, void *value)
+{
+    struct key *p, *prev_p;
+    long id = PyThread_get_thread_ident();
+
+    if (!keymutex)
+        return NULL;
+    PyThread_acquire_lock(keymutex, 1);
+    prev_p = NULL;
+    for (p = keyhead; p != NULL; p = p->next) {
+        if (p->id == id && p->key == key)
+            goto Done;
+        /* Sanity check.  These states should never happen but if
+         * they do we must abort.  Otherwise we'll end up spinning in
+         * in a tight loop with the lock held.  A similar check is done
+         * in pystate.c tstate_delete_common().  */
+        if (p == prev_p)
+            Py_FatalError("tls find_key: small circular list(!)");
+        prev_p = p;
+        if (p->next == keyhead)
+            Py_FatalError("tls find_key: circular list(!)");
+    }
+    if (value == NULL) {
+        assert(p == NULL);
+        goto Done;
+    }
+    p = (struct key *)malloc(sizeof(struct key));
+    if (p != NULL) {
+        p->id = id;
+        p->key = key;
+        p->value = value;
+        p->next = keyhead;
+        keyhead = p;
+    }
+ Done:
+    PyThread_release_lock(keymutex);
+    return p;
+}
+
+/* Return a new key.  This must be called before any other functions in
+ * this family, and callers must arrange to serialize calls to this
+ * function.  No violations are detected.
+ */
+int
+PyThread_create_key(void)
+{
+    /* All parts of this function are wrong if it's called by multiple
+     * threads simultaneously.
+     */
+    if (keymutex == NULL)
+        keymutex = PyThread_allocate_lock();
+    return ++nkeys;
+}
+
+/* Forget the associations for key across *all* threads. */
+void
+PyThread_delete_key(int key)
+{
+    struct key *p, **q;
+
+    PyThread_acquire_lock(keymutex, 1);
+    q = &keyhead;
+    while ((p = *q) != NULL) {
+        if (p->key == key) {
+            *q = p->next;
+            free((void *)p);
+            /* NB This does *not* free p->value! */
+        }
+        else
+            q = &p->next;
+    }
+    PyThread_release_lock(keymutex);
+}
+
+/* Confusing:  If the current thread has an association for key,
+ * value is ignored, and 0 is returned.  Else an attempt is made to create
+ * an association of key to value for the current thread.  0 is returned
+ * if that succeeds, but -1 is returned if there's not enough memory
+ * to create the association.  value must not be NULL.
+ */
+int
+PyThread_set_key_value(int key, void *value)
+{
+    struct key *p;
+
+    assert(value != NULL);
+    p = find_key(key, value);
+    if (p == NULL)
+        return -1;
+    else
+        return 0;
+}
+
+/* Retrieve the value associated with key in the current thread, or NULL
+ * if the current thread doesn't have an association for key.
+ */
+void *
+PyThread_get_key_value(int key)
+{
+    struct key *p = find_key(key, NULL);
+
+    if (p == NULL)
+        return NULL;
+    else
+        return p->value;
+}
+
+/* Forget the current thread's association for key, if any. */
+void
+PyThread_delete_key_value(int key)
+{
+    long id = PyThread_get_thread_ident();
+    struct key *p, **q;
+
+    PyThread_acquire_lock(keymutex, 1);
+    q = &keyhead;
+    while ((p = *q) != NULL) {
+        if (p->key == key && p->id == id) {
+            *q = p->next;
+            free((void *)p);
+            /* NB This does *not* free p->value! */
+            break;
+        }
+        else
+            q = &p->next;
+    }
+    PyThread_release_lock(keymutex);
+}
+
+/* Forget everything not associated with the current thread id.
+ * This function is called from PyOS_AfterFork().  It is necessary
+ * because other thread ids which were in use at the time of the fork
+ * may be reused for new threads created in the forked process.
+ */
+void
+PyThread_ReInitTLS(void)
+{
+    long id = PyThread_get_thread_ident();
+    struct key *p, **q;
+
+    if (!keymutex)
+        return;
+
+    /* As with interpreter_lock in PyEval_ReInitThreads()
+       we just create a new lock without freeing the old one */
+    keymutex = PyThread_allocate_lock();
+
+    /* Delete all keys which do not match the current thread id */
+    q = &keyhead;
+    while ((p = *q) != NULL) {
+        if (p->id != id) {
+            *q = p->next;
+            free((void *)p);
+            /* NB This does *not* free p->value! */
+        }
+        else
+            q = &p->next;
+    }
+}
+
+#endif  /* !MS_WINDOWS */

pypy/module/cpyext/test/test_thread.py

 
 from pypy.module.thread.ll_thread import allocate_ll_lock
 from pypy.module.cpyext.test.test_api import BaseApiTest
+from pypy.module.cpyext.test.test_cpyext import AppTestCpythonExtensionBase
 
 
 class TestPyThread(BaseApiTest):
         api.PyThread_release_lock(lock)
         assert api.PyThread_acquire_lock(lock, 0) == 1
         api.PyThread_free_lock(lock)
+
+
+class AppTestThread(AppTestCpythonExtensionBase):
+    def test_tls(self):
+        module = self.import_extension('foo', [
+            ("create_key", "METH_NOARGS",
+             """
+                 return PyInt_FromLong(PyThread_create_key());
+             """),
+            ("test_key", "METH_O",
+             """
+                 int key = PyInt_AsLong(args);
+                 if (PyThread_get_key_value(key) != NULL) {
+                     PyErr_SetNone(PyExc_ValueError);
+                     return NULL;
+                 }
+                 if (PyThread_set_key_value(key, (void*)123) < 0) {
+                     PyErr_SetNone(PyExc_ValueError);
+                     return NULL;
+                 }
+                 if (PyThread_get_key_value(key) != (void*)123) {
+                     PyErr_SetNone(PyExc_ValueError);
+                     return NULL;
+                 }
+                 Py_RETURN_NONE;
+             """),
+            ])
+        key = module.create_key()
+        assert key > 0
+        # Test value in main thread.
+        module.test_key(key)
+        raises(ValueError, module.test_key, key)
+        # Same test, in another thread.
+        result = []
+        import thread, time
+        def in_thread():
+            try:
+                module.test_key(key)
+                raises(ValueError, module.test_key, key)
+            except Exception, e:
+                result.append(e)
+            else:
+                result.append(True)
+        thread.start_new_thread(in_thread, ())
+        while not result:
+            print "."
+            time.sleep(.5)
+        assert result == [True]