Commits

Bryan O'Sullivan  committed 5292a58

Avoid the use of runInBoundThread.

Instead, wrap janky signal-unsafe MySQL API calls with signal block/unblock
actions at the C level. This should be faster than using bound threads, and
also avoids the perils of running out of OS threads documented in
http://hackage.haskell.org/trac/ghc/ticket/5174

  • Participants
  • Parent commits ccf021c

Comments (0)

Files changed (5)

File Database/MySQL/Base.hs

            withString connectPassword $ \cpass ->
             withString connectDatabase $ \cdb ->
              withString connectPath $ \cpath ->
-              withRTSSignalsBlocked $
                mysql_real_connect ptr0 chost cuser cpass cdb
                                   (fromIntegral connectPort) cpath flags
   when (ptr == nullPtr) $
   let realClose = do
         cleanupConnResult res
         wasClosed <- atomicModifyIORef closed $ \prev -> (True, prev)
-        unless wasClosed . withRTSSignalsBlocked $ mysql_close ptr
+        unless wasClosed $ mysql_close ptr
   fp <- newForeignPtr ptr realClose
   return Connection {
                connFP = fp
 {-# INLINE close #-}
 
 ping :: Connection -> IO ()
-ping conn = withConn conn $ \ptr ->
-            withRTSSignalsBlocked (mysql_ping ptr) >>= check "ping" conn
+ping conn = withConn conn $ \ptr -> mysql_ping ptr >>= check "ping" conn
 
 threadId :: Connection -> IO Word
 threadId conn = fromIntegral <$> withConn conn mysql_thread_id
 
 serverStatus :: Connection -> IO String
 serverStatus conn = withConn conn $ \ptr -> do
-  st <- withRTSSignalsBlocked $ mysql_stat ptr
+  st <- mysql_stat ptr
   checkNull "serverStatus" conn st
   peekCString st
 
 -- permanently.
 autocommit :: Connection -> Bool -> IO ()
 autocommit conn onOff = withConn conn $ \ptr ->
-   withRTSSignalsBlocked (mysql_autocommit ptr b) >>= check "autocommit" conn
+   mysql_autocommit ptr b >>= check "autocommit" conn
  where b = if onOff then 1 else 0
 
 changeUser :: Connection -> String -> String -> Maybe String -> IO ()
    withCString pass $ \cpass ->
     withMaybeString mdb $ \cdb ->
      withConn conn $ \ptr ->
-      withRTSSignalsBlocked (mysql_change_user ptr cuser cpass cdb) >>=
-      check "changeUser" conn
+      mysql_change_user ptr cuser cpass cdb >>= check "changeUser" conn
 
 selectDB :: Connection -> String -> IO ()
-selectDB conn db = 
+selectDB conn db =
   withCString db $ \cdb ->
     withConn conn $ \ptr ->
-      withRTSSignalsBlocked (mysql_select_db ptr cdb) >>= check "selectDB" conn
+      mysql_select_db ptr cdb >>= check "selectDB" conn
 
 query :: Connection -> ByteString -> IO ()
 query conn q = withConn conn $ \ptr ->
 -- Any previous outstanding 'Result' is first marked as invalid.
 useResult :: Connection -> IO Result
 useResult = frobResult "useResult" mysql_use_result
-            (withRTSSignalsBlocked . mysql_fetch_fields)
-            (withRTSSignalsBlocked . mysql_fetch_row)
-            (withRTSSignalsBlocked . mysql_fetch_lengths)
+            mysql_fetch_fields
+            mysql_fetch_row
+            mysql_fetch_lengths
 
 frobResult :: String
            -> (Ptr MYSQL -> IO (Ptr MYSQL_RES))
 frobResult func frob fetchFieldsFunc fetchRowFunc fetchLengthsFunc conn =
   withConn conn $ \ptr -> do
     cleanupConnResult (connResult conn)
-    res <- withRTSSignalsBlocked $ frob ptr
+    res <- frob ptr
     fields <- mysql_field_count ptr
     valid <- newIORef True
     if res == nullPtr
 isResultValid :: Result -> IO Bool
 isResultValid Result{..}  = readIORef resValid
 isResultValid EmptyResult = return False
-            
+
 freeResult_ :: IORef Bool -> Ptr MYSQL_RES -> IO ()
 freeResult_ valid ptr = do
   wasValid <- atomicModifyIORef valid $ \prev -> (False, prev)
   when wasValid $ mysql_free_result ptr
-    
+
 fetchRow :: Result -> IO [Maybe ByteString]
 fetchRow res@Result{..}  = withRes "fetchRow" res $ \ptr -> do
   rowPtr <- resFetchRow ptr
 nextResult :: Connection -> IO Bool
 nextResult conn = withConn conn $ \ptr -> do
   cleanupConnResult (connResult conn)
-  i <- withRTSSignalsBlocked $ mysql_next_result ptr
+  i <- mysql_next_result ptr
   case i of
     0  -> return True
     -1 -> return False

File Database/MySQL/Base/C.hsc

     -- * Error handling
     , mysql_errno
     , mysql_error
-    -- * Support functions
-    , withRTSSignalsBlocked
     ) where
 
+#include "mysql_signals.h"
 #include "mysql.h"
-#include <signal.h>
 
-import Control.Concurrent (rtsSupportsBoundThreads, runInBoundThread)
-import Control.Exception (finally)
 import Data.ByteString.Unsafe (unsafeUseAsCString)
 import Database.MySQL.Base.Types
 import Foreign.C.String (CString, withCString)
 import Foreign.C.Types (CInt, CUInt, CULLong, CULong)
-import Foreign.ForeignPtr (ForeignPtr, mallocForeignPtr, withForeignPtr)
 import Foreign.Marshal.Utils (with)
 import Foreign.Ptr (Ptr, nullPtr)
-import Foreign.Storable (Storable(..))
-import System.IO.Unsafe (unsafePerformIO)
-
--- | Execute an 'IO' action with signals used by GHC's runtime signals
--- blocked.  The @mysqlclient@ C library does not correctly restart
--- system calls if they are interrupted by signals, so many MySQL API
--- calls can unexpectedly fail when called from a Haskell application.
--- This is most likely to occur if you are linking against GHC's
--- threaded runtime (using the @-threaded@ option).
---
--- This function blocks @SIGALRM@ and @SIGVTALRM@, runs your action,
--- then unblocks those signals.  If you have a series of HDBC calls
--- that may block for a period of time, it may be wise to wrap them in
--- this action.  Blocking and unblocking signals is cheap, but not
--- free.
---
--- Here is an example of an exception that could be avoided by
--- temporarily blocking GHC's runtime signals:
---
--- >  SqlError {
--- >    seState = "", 
--- >    seNativeError = 2003, 
--- >    seErrorMsg = "Can't connect to MySQL server on 'localhost' (4)"
--- >  }
-withRTSSignalsBlocked :: IO a -> IO a
-withRTSSignalsBlocked act
-    | not rtsSupportsBoundThreads = act
-    | otherwise = runInBoundThread . withForeignPtr rtsSignals $ \set -> do
-  pthread_sigmask (#const SIG_BLOCK) set nullPtr
-  act `finally` pthread_sigmask (#const SIG_UNBLOCK) set nullPtr
-
-rtsSignals :: ForeignPtr SigSet
-rtsSignals = unsafePerformIO $ do
-               fp <- mallocForeignPtr
-               withForeignPtr fp $ \set -> do
-                 sigemptyset set
-                 sigaddset set (#const SIGALRM)
-                 sigaddset set (#const SIGVTALRM)
-               return fp
-{-# NOINLINE rtsSignals #-}
-
-data SigSet
-
-instance Storable SigSet where
-    sizeOf    _ = #{size sigset_t}
-    alignment _ = alignment (undefined :: Ptr CInt)
-
-foreign import ccall unsafe "signal.h sigaddset" sigaddset
-    :: Ptr SigSet -> CInt -> IO ()
-
-foreign import ccall unsafe "signal.h sigemptyset" sigemptyset
-    :: Ptr SigSet -> IO ()
-
-foreign import ccall unsafe "signal.h pthread_sigmask" pthread_sigmask
-    :: CInt -> Ptr SigSet -> Ptr SigSet -> IO ()
 
 foreign import ccall safe mysql_init
     :: Ptr MYSQL                -- ^ should usually be 'nullPtr'
 foreign import ccall safe "mysql.h mysql_options" mysql_options_
     :: Ptr MYSQL -> CInt -> Ptr a -> IO CInt
 
-foreign import ccall unsafe mysql_real_connect
+foreign import ccall unsafe "mysql_signals.h _hs_mysql_real_connect"
+        mysql_real_connect
     :: Ptr MYSQL -- ^ Context (from 'mysql_init').
     -> CString   -- ^ Host name.
     -> CString   -- ^ User name.

File cbits/mysql_signals.c

+/*
+ * Wrap MySQL API calls that are known to block and to be vulnerable
+ * to interruption by GHC's RTS signals.
+ */
+
+#include "mysql_signals.h"
+#include <pthread.h>
+#include <signal.h>
+#include <stdio.h>
+
+static sigset_t sigs[1];
+static int sigs_inited;
+
+static void init_rts_sigset(void)
+{
+    static pthread_mutex_t sigs_mutex = PTHREAD_MUTEX_INITIALIZER;
+
+    pthread_mutex_lock(&sigs_mutex);
+    if (!sigs_inited) {
+	sigemptyset(sigs);
+	sigaddset(sigs, SIGALRM);
+	sigaddset(sigs, SIGVTALRM);
+	sigs_inited = 1;
+    }
+    pthread_mutex_unlock(&sigs_mutex);
+}
+
+#define block_rts_signals() \
+    do { \
+        if (!sigs_inited) init_rts_sigset(); \
+        pthread_sigmask(SIG_BLOCK, sigs, NULL);	\
+    } while (0)
+
+#define unblock_rts_signals() pthread_sigmask(SIG_UNBLOCK, sigs, NULL)
+
+MYSQL *STDCALL _hs_mysql_real_connect(MYSQL *mysql, const char *host,
+				      const char *user,
+				      const char *passwd,
+				      const char *db,
+				      unsigned int port,
+				      const char *unix_socket,
+				      unsigned long clientflag)
+{
+    MYSQL *ret;
+    block_rts_signals();
+    ret = mysql_real_connect(mysql, host, user, passwd, db, port, unix_socket,
+			     clientflag);
+    unblock_rts_signals();
+
+    return ret;
+}
+
+void STDCALL _hs_mysql_close(MYSQL *sock)
+{
+    block_rts_signals();
+    mysql_close(sock);
+    unblock_rts_signals();
+}
+
+int STDCALL _hs_mysql_ping(MYSQL *mysql)
+{
+    int ret;
+    block_rts_signals();
+    ret = mysql_ping(mysql);
+    unblock_rts_signals();
+    return ret;
+}
+
+const char *STDCALL _hs_mysql_stat(MYSQL *mysql)
+{
+    const char *ret;
+    block_rts_signals();
+    ret = mysql_stat(mysql);
+    unblock_rts_signals();
+    return ret;
+}
+
+my_bool STDCALL _hs_mysql_autocommit(MYSQL *mysql, my_bool auto_mode)
+{
+    my_bool ret;
+    block_rts_signals();
+    ret = mysql_autocommit(mysql, auto_mode);
+    unblock_rts_signals();
+    return ret;
+}
+
+my_bool STDCALL _hs_mysql_change_user(MYSQL *mysql, const char *user,
+				      const char *passwd, const char *db)
+{
+    my_bool ret;
+    block_rts_signals();
+    ret = mysql_change_user(mysql, user, passwd, db);
+    unblock_rts_signals();
+    return ret;
+}
+
+int STDCALL _hs_mysql_select_db(MYSQL *mysql, const char *db)
+{
+    int ret;
+    block_rts_signals();
+    ret = mysql_select_db(mysql, db);
+    unblock_rts_signals();
+    return ret;
+}
+
+MYSQL_FIELD *STDCALL _hs_mysql_fetch_field(MYSQL_RES *result)
+{
+    MYSQL_FIELD *ret;
+    block_rts_signals();
+    ret = mysql_fetch_field(result);
+    unblock_rts_signals();
+    return ret;
+}
+
+MYSQL_ROW STDCALL _hs_mysql_fetch_row(MYSQL_RES *result)
+{
+    MYSQL_ROW ret;
+    block_rts_signals();
+    ret = mysql_fetch_row(result);
+    unblock_rts_signals();
+    return ret;
+}
+
+unsigned long *STDCALL _hs_mysql_fetch_lengths(MYSQL_RES *result)
+{
+    unsigned long *ret;
+    block_rts_signals();
+    ret = mysql_fetch_lengths(result);
+    unblock_rts_signals();
+    return ret;
+}
+
+MYSQL_RES *STDCALL _hs_mysql_store_result(MYSQL *mysql)
+{
+    MYSQL_RES *ret;
+    block_rts_signals();
+    ret = mysql_store_result(mysql);
+    unblock_rts_signals();
+    return ret;
+}
+
+MYSQL_RES *STDCALL _hs_mysql_use_result(MYSQL *mysql)
+{
+    MYSQL_RES *ret;
+    block_rts_signals();
+    ret = mysql_use_result(mysql);
+    unblock_rts_signals();
+    return ret;
+}
+
+int STDCALL _hs_mysql_next_result(MYSQL *mysql)
+{
+    int ret;
+    block_rts_signals();
+    ret = mysql_next_result(mysql);
+    unblock_rts_signals();
+    return ret;
+}

File include/mysql_signals.h

+/*
+ * Wrappers for MySQL API calls that are known to block and to be
+ * vulnerable to interruption by GHC's RTS signals.
+ */
+
+#ifndef _mysql_signals_h
+#define _mysql_signals_h
+
+#include "mysql.h"
+
+MYSQL *STDCALL _hs_mysql_real_connect(MYSQL *mysql, const char *host,
+				      const char *user,
+				      const char *passwd,
+				      const char *db,
+				      unsigned int port,
+				      const char *unix_socket,
+				      unsigned long clientflag);
+void STDCALL _hs_mysql_close(MYSQL *sock);
+int STDCALL _hs_mysql_ping(MYSQL *mysql);
+const char *STDCALL _hs_mysql_stat(MYSQL *mysql);
+my_bool STDCALL _hs_mysql_autocommit(MYSQL * mysql, my_bool auto_mode);
+my_bool STDCALL _hs_mysql_change_user(MYSQL *mysql, const char *user,
+				      const char *passwd, const char *db);
+int STDCALL _hs_mysql_select_db(MYSQL *mysql, const char *db);
+MYSQL_FIELD *STDCALL _hs_mysql_fetch_field(MYSQL_RES *result);
+MYSQL_ROW STDCALL _hs_mysql_fetch_row(MYSQL_RES *result);
+unsigned long *STDCALL _hs_mysql_fetch_lengths(MYSQL_RES *result);
+MYSQL_RES *STDCALL _hs_mysql_store_result(MYSQL *mysql);
+MYSQL_RES *STDCALL _hs_mysql_use_result(MYSQL *mysql);
+int STDCALL _hs_mysql_next_result(MYSQL *mysql);
+
+#endif /* _mysql_signals_h */
 name:           mysql
-version:        0.1.0.1
+version:        0.1.1.0
 homepage:       https://github.com/mailrank/mysql
 bug-reports:    https://github.com/mailrank/mysql/issues
 synopsis:       A low-level MySQL client library.
   default: False
 
 library
+  c-sources: cbits/mysql_signals.c
+
+  include-dirs: include
+
   exposed-modules:
     Database.MySQL.Base
     Database.MySQL.Base.C