Bryan O'Sullivan avatar Bryan O'Sullivan committed a23a646

Field metadata.

Comments (0)

Files changed (3)

Database/MySQL.hs

     , Option(..)
     , defaultConnectInfo
     , Connection
+    , Result(resConnection)
+    , Field
+    , Type
     , MySQLError(errFunction, errNumber, errMessage)
     -- * Connection management
     , connect
     , serverStatus
     -- * Querying
     , query
+    -- ** Escaping
+    , escape
     -- ** Results
     , fieldCount
     , affectedRows
-    -- * Escaping
-    , escape
+    , storeResult
+    -- * Working with results
+    , fetchFields
     -- * General information
     , clientInfo
     , clientVersion
     ) where
 
-import Data.ByteString
+import Data.ByteString.Char8
 import Data.ByteString.Internal
 import Data.ByteString.Unsafe
     
 import Control.Applicative
+import Data.Int
 import Data.Typeable (Typeable)
 import Control.Exception
 import Control.Monad
 import Foreign.C.Types
 import Foreign.ForeignPtr hiding (newForeignPtr)
 import Foreign.Concurrent
+import Foreign.Marshal.Array
 import Foreign.Ptr
 
 data ConnectInfo = ConnectInfo {
     , connClose :: Ptr MYSQL -> IO ()
     }
 
+data Result = Result {
+      resFP :: ForeignPtr MYSQL_RES
+    , resConnection :: Connection
+    }
+
 data Option = Option
             deriving (Eq, Read, Show, Typeable)
 
   unsafeUseAsCStringLen q $ \(p,l) ->
   mysql_real_query ptr p (fromIntegral l) >>= check "query" ptr
 
-fieldCount :: Connection -> IO Word
+fieldCount :: Connection -> IO Int
 fieldCount conn = withConn conn $ fmap fromIntegral . mysql_field_count
 
-affectedRows :: Connection -> IO Word64
+affectedRows :: Connection -> IO Int64
 affectedRows conn = withConn conn $ fmap fromIntegral . mysql_affected_rows
 
+storeResult :: Connection -> IO (Maybe Result)
+storeResult conn = withConn conn $ \ptr -> do
+  res <- mysql_store_result ptr
+  if res == nullPtr
+    then do
+      n <- mysql_field_count ptr
+      if n == 0
+        then return Nothing
+        else connectionError "storeResult" ptr
+    else do
+      fp <- newForeignPtr res $ mysql_free_result res
+      return . Just $ Result {
+                   resFP = fp
+                 , resConnection = conn
+                 }
+
+fetchFields :: Result -> IO [Field]
+fetchFields res = withRes res $ \ptr -> do
+  fptr <- withRTSSignalsBlocked $ mysql_fetch_fields ptr
+  n <- fieldCount (resConnection res)
+  peekArray n fptr
+
 escape :: Connection -> ByteString -> IO ByteString
 escape conn bs = withConn conn $ \ptr ->
   unsafeUseAsCStringLen bs $ \(p,l) ->
 withConn :: Connection -> (Ptr MYSQL -> IO a) -> IO a
 withConn conn = withForeignPtr (connFP conn)
 
+withRes :: Result -> (Ptr MYSQL_RES -> IO a) -> IO a
+withRes res = withForeignPtr (resFP res)
+
 withString :: String -> (CString -> IO a) -> IO a
 withString [] act = act nullPtr
 withString xs act = withCString xs act

Database/MySQL/C.hsc

-{-# LANGUAGE EmptyDataDecls, ForeignFunctionInterface #-}
+{-# LANGUAGE DeriveDataTypeable, EmptyDataDecls, ForeignFunctionInterface #-}
 
 module Database.MySQL.C
     (
     -- * Types
-      MYSQL
+    -- * High-level types
+      Type(..)
+    , Field(..)
+    , FieldFlag
+    , FieldFlags
+    -- * Low-level types
+    , MYSQL
+    , MYSQL_RES
     , MYSQL_STMT
     , MyBool
     -- * Connection management
     , mysql_stat
     -- * Querying
     , mysql_real_query
+    -- ** Escaping
+    , mysql_real_escape_string
     -- ** Results
     , mysql_field_count
     , mysql_affected_rows
-    -- * Escaping
-    , mysql_real_escape_string
+    , mysql_store_result
+    -- * Working with results
+    , mysql_free_result
+    , mysql_fetch_fields
+    -- * Field flags
+    , hasAllFlags
+    , flagNotNull
+    , flagPrimaryKey
+    , flagUniqueKey
+    , flagMultipleKey
+    , flagUnsigned
+    , flagZeroFill
+    , flagBinary
+    , flagAutoIncrement
+    , flagNumeric
+    , flagNoDefaultValue
     -- * General information
     , mysql_get_client_info
     , mysql_get_client_version
 #include "mysql.h"
 #include <signal.h>
 
+import Data.Monoid
+import Data.Bits
+import Data.List
+import Control.Applicative
+import Data.Maybe
+import qualified Data.IntMap as IntMap
 import Control.Concurrent (rtsSupportsBoundThreads, runInBoundThread)
 import Control.Exception (finally)
 import Foreign.C.String (CString)
 import Foreign.C.Types
 import Foreign.ForeignPtr (ForeignPtr, mallocForeignPtr, withForeignPtr)
-import Foreign.Ptr (Ptr, nullPtr)
+import Foreign.Ptr (Ptr, castPtr, nullPtr)
 import Foreign.Storable (Storable(..))
 import System.IO.Unsafe (unsafePerformIO)
+import Foreign.Storable
+import Data.Typeable (Typeable)
+import Data.ByteString hiding (intercalate)
+import Data.ByteString.Internal
+import Data.Word
 
 data MYSQL
+data MYSQL_RES
 data MYSQL_STMT
 type MyBool = CChar
 
+-- | Column types supported by MySQL.
+data Type = Decimal
+          | Tiny
+          | Short
+          | Long
+          | Float
+          | Double
+          | Null
+          | Timestamp
+          | LongLong
+          | Int24
+          | Date
+          | Time
+          | DateTime
+          | Year
+          | NewDate
+          | VarChar
+          | Bit
+          | NewDecimal
+          | Enum
+          | Set
+          | TinyBlob
+          | MediumBlob
+          | LongBlob
+          | Blob
+          | VarString
+          | String
+          | Geometry
+            deriving (Enum, Eq, Show, Typeable)
+
+toType :: CInt -> Type
+toType v = IntMap.findWithDefault oops (fromIntegral v) typeMap
+  where
+    oops = error $ "Database.MySQL: unknown field type " ++ show v
+    typeMap = IntMap.fromList [
+               ((#const MYSQL_TYPE_DECIMAL), Decimal),
+               ((#const MYSQL_TYPE_TINY), Tiny),
+               ((#const MYSQL_TYPE_SHORT), Short),
+               ((#const MYSQL_TYPE_LONG), Long),
+               ((#const MYSQL_TYPE_FLOAT), Float),
+               ((#const MYSQL_TYPE_DOUBLE), Double),
+               ((#const MYSQL_TYPE_NULL), Null),
+               ((#const MYSQL_TYPE_TIMESTAMP), Timestamp),
+               ((#const MYSQL_TYPE_LONGLONG), LongLong),
+               ((#const MYSQL_TYPE_DATE), Date),
+               ((#const MYSQL_TYPE_TIME), Time),
+               ((#const MYSQL_TYPE_DATETIME), DateTime),
+               ((#const MYSQL_TYPE_YEAR), Year),
+               ((#const MYSQL_TYPE_NEWDATE), NewDate),
+               ((#const MYSQL_TYPE_VARCHAR), VarChar),
+               ((#const MYSQL_TYPE_BIT), Bit),
+               ((#const MYSQL_TYPE_NEWDECIMAL), NewDecimal),
+               ((#const MYSQL_TYPE_ENUM), Enum),
+               ((#const MYSQL_TYPE_SET), Set),
+               ((#const MYSQL_TYPE_TINY_BLOB), TinyBlob),
+               ((#const MYSQL_TYPE_MEDIUM_BLOB), MediumBlob),
+               ((#const MYSQL_TYPE_LONG_BLOB), LongBlob),
+               ((#const MYSQL_TYPE_BLOB), Blob),
+               ((#const MYSQL_TYPE_VAR_STRING), VarString),
+               ((#const MYSQL_TYPE_STRING), String),
+               ((#const MYSQL_TYPE_GEOMETRY), Geometry)
+              ]
+
+-- | A description of a field (column) of a table.
+data Field = Field {
+      fieldName :: ByteString   -- ^ Name of column.
+    , fieldOrigName :: ByteString -- ^ Original column name, if an alias.
+    , fieldTable :: ByteString -- ^ Table of column, if column was a field.
+    , fieldOrigTable :: ByteString -- ^ Original table name, if table was an alias.
+    , fieldDB :: ByteString        -- ^ Database for table.
+    , fieldCatalog :: ByteString   -- ^ Catalog for table.
+    , fieldDefault :: Maybe ByteString   -- ^ Default value.
+    , fieldLength :: Word          -- ^ Width of column (create length).
+    , fieldMaxLength :: Word    -- ^ Maximum width for selected set.
+    , fieldFlags :: FieldFlags        -- ^ Div flags.
+    , fieldDecimals :: Word -- ^ Number of decimals in field.
+    , fieldCharSet :: Word -- ^ Character set number.
+    , fieldType :: Type
+    } deriving (Eq, Show, Typeable)
+
+newtype FieldFlags = FieldFlags CUInt
+    deriving (Eq, Typeable)
+
+instance Show FieldFlags where
+    show f = '[' : z ++ "]"
+      where z = intercalate "," . catMaybes $ [
+                          flagNotNull ??? "flagNotNull"
+                        , flagPrimaryKey ??? "flagPrimaryKey"
+                        , flagUniqueKey ??? "flagUniqueKey"
+                        , flagMultipleKey ??? "flagMultipleKey"
+                        , flagUnsigned ??? "flagUnsigned"
+                        , flagZeroFill ??? "flagZeroFill"
+                        , flagBinary ??? "flagBinary"
+                        , flagAutoIncrement ??? "flagAutoIncrement"
+                        , flagNumeric ??? "flagNumeric"
+                        , flagNoDefaultValue ??? "flagNoDefaultValue"
+                        ]
+            flag ??? name | f `hasAllFlags` flag = Just name
+                          | otherwise            = Nothing
+
+type FieldFlag = FieldFlags
+
+instance Monoid FieldFlags where
+    mempty = FieldFlags 0
+    {-# INLINE mempty #-}
+    mappend (FieldFlags a) (FieldFlags b) = FieldFlags (a .|. b)
+    {-# INLINE mappend #-}
+
+flagNotNull, flagPrimaryKey, flagUniqueKey, flagMultipleKey :: FieldFlag
+flagNotNull = FieldFlags #const NOT_NULL_FLAG
+flagPrimaryKey = FieldFlags #const PRI_KEY_FLAG
+flagUniqueKey = FieldFlags #const UNIQUE_KEY_FLAG
+flagMultipleKey = FieldFlags #const MULTIPLE_KEY_FLAG
+
+flagUnsigned, flagZeroFill, flagBinary, flagAutoIncrement :: FieldFlag
+flagUnsigned = FieldFlags #const UNSIGNED_FLAG
+flagZeroFill = FieldFlags #const ZEROFILL_FLAG
+flagBinary = FieldFlags #const BINARY_FLAG
+flagAutoIncrement = FieldFlags #const AUTO_INCREMENT_FLAG
+
+flagNumeric, flagNoDefaultValue :: FieldFlag
+flagNumeric = FieldFlags #const NUM_FLAG
+flagNoDefaultValue = FieldFlags #const NO_DEFAULT_VALUE_FLAG
+
+hasAllFlags :: FieldFlags -> FieldFlags -> Bool
+FieldFlags a `hasAllFlags` FieldFlags b = a .&. b == b
+{-# INLINE hasAllFlags #-}
+
+peekField :: Ptr Field -> IO Field
+peekField ptr = do
+  flags <- FieldFlags <$> (#peek MYSQL_FIELD, flags) ptr
+  Field
+   <$> peekS ((#peek MYSQL_FIELD, name)) ((#peek MYSQL_FIELD, name_length))
+   <*> peekS ((#peek MYSQL_FIELD, org_name)) ((#peek MYSQL_FIELD, org_name_length))
+   <*> peekS ((#peek MYSQL_FIELD, table)) ((#peek MYSQL_FIELD, table_length))
+   <*> peekS ((#peek MYSQL_FIELD, org_table)) ((#peek MYSQL_FIELD, org_table_length))
+   <*> peekS ((#peek MYSQL_FIELD, db)) ((#peek MYSQL_FIELD, db_length))
+   <*> peekS ((#peek MYSQL_FIELD, catalog)) ((#peek MYSQL_FIELD, catalog_length))
+   <*> (if flags `hasAllFlags` flagNoDefaultValue
+       then pure Nothing
+       else Just <$> peekS ((#peek MYSQL_FIELD, def)) ((#peek MYSQL_FIELD, def_length)))
+   <*> (uint <$> (#peek MYSQL_FIELD, length) ptr)
+   <*> (uint <$> (#peek MYSQL_FIELD, max_length) ptr)
+   <*> pure flags
+   <*> (uint <$> (#peek MYSQL_FIELD, decimals) ptr)
+   <*> (uint <$> (#peek MYSQL_FIELD, charsetnr) ptr)
+   <*> (toType <$> (#peek MYSQL_FIELD, type) ptr)
+ where
+   uint = fromIntegral :: CUInt -> Word
+   peekS :: (Ptr Field -> IO (Ptr Word8)) -> (Ptr Field -> IO CUInt)
+         -> IO ByteString
+   peekS peekPtr peekLen = do
+     p <- peekPtr ptr
+     l <- peekLen ptr
+     create (fromIntegral l) $ \d -> memcpy d p (fromIntegral l)
+
+instance Storable Field where
+    sizeOf _    = #{size MYSQL_FIELD}
+    alignment _ = alignment (undefined :: Ptr CChar)
+    peek = peekField
+
 -- | Execute an 'IO' action with signals used by GHC's runtime signals
 -- blocked.  The @mysqlclient@ C library does not correctly restart
 -- system calls if they are interrupted by signals, so many MySQL API
 foreign import ccall safe mysql_affected_rows
     :: Ptr MYSQL -> IO CULLong
 
+foreign import ccall unsafe mysql_store_result
+    :: Ptr MYSQL -> IO (Ptr MYSQL_RES)
+
+foreign import ccall unsafe mysql_free_result
+    :: Ptr MYSQL_RES -> IO ()
+
+foreign import ccall unsafe mysql_fetch_fields
+    :: Ptr MYSQL_RES -> IO (Ptr Field)
+
 foreign import ccall safe mysql_real_escape_string
     :: Ptr MYSQL -> CString -> CString -> CULong -> IO CULong
 
 
   build-depends:
     base       < 5,
-    bytestring >= 0.9 && < 1.0
+    bytestring >= 0.9 && < 1.0,
+    containers
 
   ghc-options: -Wall
   if impl(ghc >= 6.8)
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.