Commits

Bryan O'Sullivan committed 37129b0

Tons of code changes and some doc notes.

In brief:

Allow both blocking and non-blocking use of the C API.

Document the rules (at least as I imagine them to be) for object lifetime
management.

A pony.

  • Participants
  • Parent commits 6048001

Comments (0)

Files changed (2)

File Database/MySQL.hs

 
 module Database.MySQL
     (
+    -- * Resource management
+    -- $mgmt
     -- * Types
       ConnectInfo(..)
     , Option(..)
     , defaultConnectInfo
     , Connection
-    , Result(resConnection, resFields)
+    , Result
     , Field
     , Type
     , MySQLError(errFunction, errNumber, errMessage)
     , fieldCount
     , affectedRows
     -- * Working with results
+    , isResultValid
+    , freeResult
     , storeResult
     , useResult
     , fetchRow
     , clientVersion
     ) where
 
-import Data.ByteString.Char8 (ByteString)
+import Data.ByteString.Char8 ()
 import Data.ByteString.Internal
 import Data.ByteString.Unsafe
 import Database.MySQL.Types
+import System.Mem.Weak
     
 import Control.Applicative
 import Data.Int
 import Foreign.Marshal.Array
 import Foreign.Ptr
 
+-- $mgmt
+--
+-- Our rules for managing 'Connection' and 'Result' values are
+-- unfortunately complicated, thanks to MySQL's lifetime rules.
+--
+-- At the C @libmysqlclient@ level, a single @MYSQL@ connection may
+-- cause multiple @MYSQL_RES@ result values to be created over the
+-- course of multiple queries, but only one of these @MYSQL_RES@
+-- values may be alive at a time.  The programmer is responsible for
+-- knowing when to call @mysql_free_result@.
+--
+-- Meanwhile, up in Haskell-land, we'd like both 'Connection' and
+-- 'Result' values to be managed either manually or automatically. In
+-- particular, we want finalizers to tidy up after a messy programmer,
+-- and we'd prefer it if people didn't need to be mindful of calling
+-- @mysql_free_result@. This means that we must wrestle with the
+-- lifetime rules. An obvious approach would be to use some monad and
+-- type magic to enforce those rules, but then we'd end up with an
+-- awkward API.
+--
+-- Instead, we allow 'Result' values to stay alive for arbitrarily
+-- long times, while preserving the right to mark them as
+-- invalid. Since all functions over @Result@ values are in the 'IO'
+-- monad, we don't risk disrupting pure code by introducing this
+-- mutability. Code that tries to access a @Result@ that fails
+-- 'isResultValid' will be thrown a 'MySQLError'. This should /not/
+-- occur in normal code, so there should be no need to test a @Result@
+-- for validity.
+--
+-- A 'Result' must be able to keep a 'Connection' alive so that a
+-- streaming @Result@ constructed by 'useResult' can continue to pull
+-- data from the server, but a @Connection@ must (a) be able to cause
+-- the @MYSQL_RES@ behind a @Result@ to be deleted at a moment's notice,
+-- while (b) not artificially prolonging the life of either the @Result@
+-- or its @MYSQL_RES@.
+
 data ConnectInfo = ConnectInfo {
       connectHost :: String
     , connectPort :: Word16
       errFunction :: String
     , errNumber :: Int
     , errMessage :: String
+    } | ResultError {
+      errFunction :: String
+    , errNumber :: Int
+    , errMessage :: String
     } deriving (Eq, Show, Typeable)
 
 instance Exception MySQLError
 
 data Connection = Connection {
       connFP :: ForeignPtr MYSQL
-    , connClose :: Ptr MYSQL -> IO ()
+    , connClose :: IO ()
+    , connResult :: IORef (Maybe (Weak Result))
     }
 
 data Result = Result {
       resFP :: ForeignPtr MYSQL_RES
     , resFields :: {-# UNPACK #-} !Int
     , resConnection :: Connection
-    }
+    , resValid :: IORef Bool
+    , resFetchFields :: Ptr MYSQL_RES -> IO (Ptr Field)
+    , resFetchRow :: Ptr MYSQL_RES -> IO MYSQL_ROW
+    , resFetchLengths :: Ptr MYSQL_RES -> IO (Ptr CULong)
+    } | EmptyResult
 
 data Option = Option
             deriving (Eq, Read, Show, Typeable)
                                  (fromIntegral connectPort)
   when (ptr == nullPtr) $
     connectionError_ "connect" ptr0
-  fp <- newForeignPtr ptr $ realClose closed ptr
+  res <- newIORef Nothing
+  let realClose = do
+        cleanupConnResult res
+        wasClosed <- atomicModifyIORef closed $ \prev -> (True, prev)
+        unless wasClosed . withRTSSignalsBlocked $ mysql_close ptr
+  fp <- newForeignPtr ptr realClose
   return Connection {
                connFP = fp
-             , connClose = realClose closed
+             , connClose = realClose
+             , connResult = res
              }
 
+-- | Delete the 'MYSQL_RES' behind a 'Result' immediately, and mark
+-- the 'Result' as invalid.
+cleanupConnResult :: IORef (Maybe (Weak Result)) -> IO ()
+cleanupConnResult res = do
+  prev <- readIORef res
+  case prev of
+    Nothing -> return ()
+    Just w -> maybe (return ()) freeResult =<< deRefWeak w
+
 close :: Connection -> IO ()
-close conn = withConn conn (connClose conn)
-
-realClose :: IORef Bool -> Ptr MYSQL -> IO ()
-realClose closeInfo ptr = do
-  wasClosed <- atomicModifyIORef closeInfo $ \prev -> (True, prev)
-  unless wasClosed . withRTSSignalsBlocked $ mysql_close ptr
+close = connClose
+{-# INLINE close #-}
 
 ping :: Connection -> IO ()
 ping conn = withConn conn $ \ptr ->
   unsafeUseAsCStringLen q $ \(p,l) ->
   mysql_real_query ptr p (fromIntegral l) >>= check "query" conn
 
-fieldCount :: Connection -> IO Int
-fieldCount conn = withConn conn $ fmap fromIntegral . mysql_field_count
+fieldCount :: Either Connection Result -> IO Int
+fieldCount (Right EmptyResult) = return 0
+fieldCount (Right res)         = return (resFields res)
+fieldCount (Left conn)         =
+    withConn conn $ fmap fromIntegral . mysql_field_count
 
 affectedRows :: Connection -> IO Int64
 affectedRows conn = withConn conn $ fmap fromIntegral . mysql_affected_rows
 
-storeResult :: Connection -> IO (Maybe Result)
+storeResult :: Connection -> IO Result
 storeResult = frobResult "storeResult" mysql_store_result
+              mysql_fetch_fields_nonblock
+              mysql_fetch_row_nonblock
+              mysql_fetch_lengths_nonblock
 
-useResult :: Connection -> IO (Maybe Result)
+useResult :: Connection -> IO Result
 useResult = frobResult "useResult" mysql_use_result
+            (withRTSSignalsBlocked . mysql_fetch_fields)
+            (withRTSSignalsBlocked . mysql_fetch_row)
+            (withRTSSignalsBlocked . mysql_fetch_lengths)
 
-frobResult :: String -> (Ptr MYSQL -> IO (Ptr MYSQL_RES))
-           -> Connection -> IO (Maybe Result)
-frobResult func frob conn = withConn conn $ \ptr -> do
-  res <- withRTSSignalsBlocked $ frob ptr
-  fields <- mysql_field_count ptr
-  if res == nullPtr
-    then if fields == 0
-         then return Nothing
-         else connectionError func conn
-    else do
-      fp <- newForeignPtr res $ mysql_free_result res
-      return . Just $ Result {
-                   resFP = fp
-                 , resFields = fromIntegral fields
-                 , resConnection = conn
-                 }
+frobResult :: String
+           -> (Ptr MYSQL -> IO (Ptr MYSQL_RES))
+           -> (Ptr MYSQL_RES -> IO (Ptr Field))
+           -> (Ptr MYSQL_RES -> IO MYSQL_ROW)
+           -> (Ptr MYSQL_RES -> IO (Ptr CULong))
+           -> Connection -> IO Result
+frobResult func frob fetchFieldsFunc fetchRowFunc fetchLengthsFunc conn =
+  withConn conn $ \ptr -> do
+    cleanupConnResult (connResult conn)
+    res <- withRTSSignalsBlocked $ frob ptr
+    fields <- mysql_field_count ptr
+    valid <- newIORef True
+    if res == nullPtr
+      then if fields == 0
+           then return EmptyResult
+           else connectionError func conn
+      else do
+        fp <- newForeignPtr res $ freeResult_ valid res
+        let ret = Result {
+                    resFP = fp
+                  , resFields = fromIntegral fields
+                  , resConnection = conn
+                  , resValid = valid
+                  , resFetchFields = fetchFieldsFunc
+                  , resFetchRow = fetchRowFunc
+                  , resFetchLengths = fetchLengthsFunc
+                  }
+        weak <- mkWeakPtr ret (Just (freeResult_ valid res))
+        writeIORef (connResult conn) (Just weak)
+        return ret
 
+freeResult :: Result -> IO ()
+freeResult Result{..}      = withForeignPtr resFP $ freeResult_ resValid
+freeResult EmptyResult{..} = return ()
+
+isResultValid :: Result -> IO Bool
+isResultValid Result{..}  = readIORef resValid
+isResultValid EmptyResult = return False
+            
+freeResult_ :: IORef Bool -> Ptr MYSQL_RES -> IO ()
+freeResult_ valid ptr = do
+  wasValid <- atomicModifyIORef valid $ \prev -> (False, prev)
+  when wasValid $ mysql_free_result ptr
+    
 fetchRow :: Result -> IO [Maybe ByteString]
-fetchRow res@Result{..}
-    | resFields == 0 = return []
-    | otherwise      = withRes res $ \ptr -> do
-  rowPtr <- withRTSSignalsBlocked $ mysql_fetch_row ptr
+fetchRow res@Result{..}  = withRes "fetchRow" res $ \ptr -> do
+  rowPtr <- resFetchRow ptr
   if rowPtr == nullPtr
     then return []
     else do
-      lenPtr <- mysql_fetch_lengths ptr
+      lenPtr <- resFetchLengths ptr
       checkNull "fetchRow" resConnection lenPtr
       let go len = withPtr $ \colPtr ->
                    create (fromIntegral len) $ \d ->
                    memcpy d (castPtr colPtr) (fromIntegral len)
       sequence =<< zipWith go <$> peekArray resFields lenPtr
                               <*> peekArray resFields rowPtr
+fetchRow EmptyResult{..} = return []
 
 fetchFields :: Result -> IO [Field]
-fetchFields res = withRes res $ \ptr -> do
-  fptr <- withRTSSignalsBlocked $ mysql_fetch_fields ptr
-  n <- fieldCount (resConnection res)
-  peekArray n fptr
+fetchFields res@Result{..} = withRes "fetchFields" res $ \ptr -> do
+  peekArray resFields =<< resFetchFields ptr
+fetchFields EmptyResult{..} = return []
 
 nextResult :: Connection -> IO Bool
 nextResult conn = withConn conn $ \ptr -> do
 withConn :: Connection -> (Ptr MYSQL -> IO a) -> IO a
 withConn conn = withForeignPtr (connFP conn)
 
-withRes :: Result -> (Ptr MYSQL_RES -> IO a) -> IO a
-withRes res = withForeignPtr (resFP res)
+withRes :: String -> Result -> (Ptr MYSQL_RES -> IO a) -> IO a
+withRes func res act = do
+  valid <- readIORef (resValid res)
+  unless valid . throw $ ResultError func 0 "result is no longer usable"
+  withForeignPtr (resFP res) act
 
 withString :: String -> (CString -> IO a) -> IO a
 withString [] act = act nullPtr

File Database/MySQL/C.hsc

     , mysql_store_result
     , mysql_use_result
     , mysql_fetch_lengths
+    , mysql_fetch_lengths_nonblock
     , mysql_fetch_row
+    , mysql_fetch_row_nonblock
     -- * Working with results
     , mysql_free_result
     , mysql_fetch_fields
+    , mysql_fetch_fields_nonblock
     -- ** Multiple results
     , mysql_next_result
     -- * General information
 foreign import ccall unsafe mysql_fetch_fields
     :: Ptr MYSQL_RES -> IO (Ptr Field)
 
+foreign import ccall safe "mysql.h mysql_fetch_fields" mysql_fetch_fields_nonblock
+    :: Ptr MYSQL_RES -> IO (Ptr Field)
+
 foreign import ccall unsafe mysql_next_result
     :: Ptr MYSQL -> IO CInt
 
 foreign import ccall unsafe mysql_fetch_row
     :: Ptr MYSQL_RES -> IO MYSQL_ROW
 
+foreign import ccall safe "mysql.h mysql_fetch_row" mysql_fetch_row_nonblock
+    :: Ptr MYSQL_RES -> IO MYSQL_ROW
+
 foreign import ccall unsafe mysql_fetch_lengths
     :: Ptr MYSQL_RES -> IO (Ptr CULong)
 
+foreign import ccall safe "mysql.h mysql_fetch_lengths" mysql_fetch_lengths_nonblock
+    :: Ptr MYSQL_RES -> IO (Ptr CULong)
+
 foreign import ccall safe mysql_real_escape_string
     :: Ptr MYSQL -> CString -> CString -> CULong -> IO CULong