Anonymous avatar Anonymous committed 37cbaef

use more specific type, handle other date/time types, fix getTables and describeTable

Comments (0)

Files changed (1)

Database/HDBC/MySQL/Connection.hsc

 import Foreign.C
 import qualified Foreign.Concurrent
 import qualified Data.ByteString as B
-import Data.ByteString.UTF8 (fromString)
+import Data.ByteString.UTF8 (fromString, toString)
 import Data.List (isPrefixOf)
 import Data.Time
 import Data.Time.Clock.POSIX
   bindOfSqlValue' (8::Int) buf_ #{const MYSQL_TYPE_LONGLONG} Unsigned
 
 bindOfSqlValue (Types.SqlEpochTime epoch) = do
-  let t = utcToMysqlTime $ posixSecondsToUTCTime (fromIntegral epoch)
-  buf_ <- new t
-  bindOfSqlValue' (#{const sizeof(MYSQL_TIME)}::Int) buf_ #{const MYSQL_TYPE_DATETIME} Signed
-      where utcToMysqlTime :: UTCTime -> MYSQL_TIME
-            utcToMysqlTime (UTCTime day difftime) =
-                let (y, m, d) = toGregorian day
-                    t  = floor $ (realToFrac difftime :: Double)
-                    h  = t `div` 3600
-                    mn = t `div` 60 `mod` 60
-                    s  = t `mod` 60
-                in MYSQL_TIME (fromIntegral y) (fromIntegral m) (fromIntegral d) h mn s
+  bindOfSqlValue $ Types.SqlPOSIXTime $ fromIntegral epoch
 
 bindOfSqlValue (Types.SqlTimeDiff n) = do
   let h  = fromIntegral $ n `div` 3600
   buf_ <- new t
   bindOfSqlValue' (#{const sizeof(MYSQL_TIME)}::Int) buf_ #{const MYSQL_TYPE_TIME} Signed
 
-bindOfSqlValue (Types.SqlLocalDate _) =
-    error "SqlLocalDate: bind type not implemented"
+bindOfSqlValue (Types.SqlLocalDate day) = do
+  let (y, m, d) = toGregorian day
+      t         = MYSQL_TIME (fromIntegral y) (fromIntegral m) (fromIntegral d) 0 0 0
+  buf_ <- new t
+  bindOfSqlValue' (#{const sizeof(MYSQL_TIME)}::Int) buf_ #{const MYSQL_TYPE_DATE} Signed
 
-bindOfSqlValue (Types.SqlLocalTimeOfDay _) =
-    error "SqlLocalTimeOfDay: bind type not implemented"
+bindOfSqlValue (Types.SqlLocalTimeOfDay time) = do
+  let h  = fromIntegral $ todHour time
+      mn = fromIntegral $ todMin time
+      s  = floor $ todSec time
+      t  = MYSQL_TIME 0 0 0 h mn s
+  buf_ <- new t
+  bindOfSqlValue' (#{const sizeof(MYSQL_TIME)}::Int) buf_ #{const MYSQL_TYPE_TIME} Signed
 
-bindOfSqlValue (Types.SqlZonedLocalTimeOfDay _ _) =
-    error "SqlZonedLocalTimeOfDay: bind type not implemented"
+bindOfSqlValue (Types.SqlZonedLocalTimeOfDay t _) =
+  bindOfSqlValue $ Types.SqlLocalTimeOfDay t
 
-bindOfSqlValue (Types.SqlLocalTime _) =
-    error "SqlLocalTime: bind type not implemented"
+bindOfSqlValue (Types.SqlLocalTime (LocalTime day time)) = do
+  let (y, m, d) = toGregorian day
+      h         = fromIntegral $ todHour time
+      mn        = fromIntegral $ todMin time
+      s         = floor $ todSec time
+      t         = MYSQL_TIME (fromIntegral y) (fromIntegral m) (fromIntegral d) h mn s
+  buf_ <- new t
+  bindOfSqlValue' (#{const sizeof(MYSQL_TIME)}::Int) buf_ #{const MYSQL_TYPE_DATETIME} Signed
 
-bindOfSqlValue (Types.SqlZonedTime _) =
-    error "SqlZonedTime: bind type not implemented"
+bindOfSqlValue (Types.SqlZonedTime t) =
+  bindOfSqlValue $ Types.SqlLocalTime $ zonedTimeToLocalTime t
 
-bindOfSqlValue (Types.SqlUTCTime _) =
-    error "SqlUTCTime: bind type not implemented"
+bindOfSqlValue (Types.SqlUTCTime (UTCTime day time)) = do
+  let (y, m, d) = toGregorian day
+      t         = floor $ (realToFrac time :: Double)
+      h         = t `div` 3600
+      mn        = t `div` 60 `mod` 60
+      s         = t `mod` 60
+      t'        = MYSQL_TIME (fromIntegral y) (fromIntegral m) (fromIntegral d) h mn s
+  buf_ <- new t'
+  bindOfSqlValue' (#{const sizeof(MYSQL_TIME)}::Int) buf_ #{const MYSQL_TYPE_DATETIME} Signed
 
-bindOfSqlValue (Types.SqlDiffTime _) =
-    error "SqlDiffTime: bind type not implemented"
+bindOfSqlValue (Types.SqlDiffTime t) =
+  bindOfSqlValue $ Types.SqlPOSIXTime t
 
-bindOfSqlValue (Types.SqlPOSIXTime _) =
-    error "SqlPOSIXtime: bind type not implemented"
+bindOfSqlValue (Types.SqlPOSIXTime t) =
+  bindOfSqlValue $ Types.SqlUTCTime $ posixSecondsToUTCTime t
 
 -- A nasty helper function that cuts down on the boilerplate a bit.
 bindOfSqlValue' :: (Integral a, Storable b) => a -> Ptr b -> CInt -> Signedness -> IO MYSQL_BIND
 boundType #{const MYSQL_TYPE_NEWDECIMAL} 0 = #{const MYSQL_TYPE_LONGLONG}
 boundType #{const MYSQL_TYPE_NEWDECIMAL} _ = #{const MYSQL_TYPE_DOUBLE}
 boundType #{const MYSQL_TYPE_FLOAT}      _ = #{const MYSQL_TYPE_DOUBLE}
-boundType #{const MYSQL_TYPE_DATE}       _ = #{const MYSQL_TYPE_DATETIME}
-boundType #{const MYSQL_TYPE_TIMESTAMP}  _ = #{const MYSQL_TYPE_DATETIME}
-boundType #{const MYSQL_TYPE_NEWDATE}    _ = #{const MYSQL_TYPE_DATETIME}
 boundType #{const MYSQL_TYPE_BLOB}       _ = #{const MYSQL_TYPE_VAR_STRING}
 boundType t                              _ = t
 
 -- Returns the amount of storage required for a particular result
 -- type.
 boundSize :: CInt -> CULong -> CULong
-boundSize #{const MYSQL_TYPE_LONG}     _ = 4
-boundSize #{const MYSQL_TYPE_DOUBLE}   _ = 8
-boundSize #{const MYSQL_TYPE_DATETIME} _ = #{const sizeof(MYSQL_TIME)}
-boundSize _                            n = n
+boundSize #{const MYSQL_TYPE_LONG}      _ = 4
+boundSize #{const MYSQL_TYPE_DOUBLE}    _ = 8
+boundSize #{const MYSQL_TYPE_DATETIME}  _ = #{const sizeof(MYSQL_TIME)}
+boundSize #{const MYSQL_TYPE_TIME}      _ = #{const sizeof(MYSQL_TIME)}
+boundSize #{const MYSQL_TYPE_NEWDATE}   _ = #{const sizeof(MYSQL_TIME)}
+boundSize #{const MYSQL_TYPE_DATE}      _ = #{const sizeof(MYSQL_TIME)}
+boundSize #{const MYSQL_TYPE_TIMESTAMP} _ = #{const sizeof(MYSQL_TIME)}
+boundSize _                             n = n
 
 -- Fetches a row from an executed statement and converts the cell
 -- values into a list of SqlValue types.
   if isNull == 0 then cellValue' else return Types.SqlNull
       where cellValue' = do
                    len <- peek $ bindLength bind
-                   let buftype = bindBufferType bind
-                       buf     = bindBuffer bind
-                   nonNullCellValue buftype buf len
+                   let buftype  = bindBufferType bind
+                       buf      = bindBuffer bind
+                       unsigned = bindIsUnsigned bind == 1
+                   nonNullCellValue buftype buf len unsigned
 
 -- Produces a single SqlValue from the binding's type and buffer
 -- pointer.  It assumes that the value is not null.
-nonNullCellValue :: CInt -> Ptr () -> CULong -> IO Types.SqlValue
+nonNullCellValue :: CInt -> Ptr () -> CULong -> Bool -> IO Types.SqlValue
 
-nonNullCellValue #{const MYSQL_TYPE_LONG} p _ = do
+nonNullCellValue #{const MYSQL_TYPE_LONG} p _ u = do
   n :: CInt <- peek $ castPtr p
-  return $ Types.SqlInteger (fromIntegral n)
+  return $ if u then Types.SqlWord32 (fromIntegral n)
+                else Types.SqlInt32 (fromIntegral n)
 
-nonNullCellValue #{const MYSQL_TYPE_LONGLONG} p _ = do
+nonNullCellValue #{const MYSQL_TYPE_LONGLONG} p _ u = do
   n :: CLLong <- peek $ castPtr p
-  return $ Types.SqlInteger (fromIntegral n)
+  return $ if u then Types.SqlWord64 (fromIntegral n)
+                else Types.SqlInt64 (fromIntegral n)
 
-nonNullCellValue #{const MYSQL_TYPE_DOUBLE} p _ = do
+nonNullCellValue #{const MYSQL_TYPE_DOUBLE} p _ _ = do
   n :: CDouble <- peek $ castPtr p
   return $ Types.SqlDouble (realToFrac n)
 
-nonNullCellValue #{const MYSQL_TYPE_VAR_STRING} p len =
+nonNullCellValue #{const MYSQL_TYPE_VAR_STRING} p len _ =
     B.packCStringLen ((castPtr p), fromIntegral len) >>= return . Types.SqlByteString
 
-nonNullCellValue #{const MYSQL_TYPE_DATETIME} p _ = do
+nonNullCellValue #{const MYSQL_TYPE_TIMESTAMP} p _ _ = do
   t :: MYSQL_TIME <- peek $ castPtr p
-  let epoch = (floor . toRational . utcTimeToPOSIXSeconds . mysqlTimeToUTC) t
-  return $ Types.SqlEpochTime epoch
+  let secs = (utcTimeToPOSIXSeconds . mysqlTimeToUTC) t
+  return $ Types.SqlPOSIXTime secs
       where mysqlTimeToUTC :: MYSQL_TIME -> UTCTime
             mysqlTimeToUTC (MYSQL_TIME y m d h mn s) =
                 -- XXX so, this is fine if the date we're getting back
                     time = s + mn * 60 + h * 3600
                 in UTCTime day (secondsToDiffTime $ fromIntegral time)
 
-nonNullCellValue #{const MYSQL_TYPE_TIME} p _ = do
+nonNullCellValue #{const MYSQL_TYPE_DATETIME} p _ _ = do
+  (MYSQL_TIME y m d h mn s) <- peek $ castPtr p
+  let date = fromGregorian (fromIntegral y) (fromIntegral m) (fromIntegral d)
+      time = TimeOfDay (fromIntegral h) (fromIntegral mn) (fromIntegral s)
+  return $ Types.SqlLocalTime (LocalTime date time)
+
+nonNullCellValue #{const MYSQL_TYPE_TIME} p _ _ = do
   (MYSQL_TIME _ _ _ h mn s) <- peek $ castPtr p
-  let secs = 3600 * h + 60 * mn + s
-  return $ Types.SqlTimeDiff (fromIntegral secs)
+  let time = TimeOfDay (fromIntegral h) (fromIntegral mn) (fromIntegral s)
+  return $ Types.SqlLocalTimeOfDay time
 
-nonNullCellValue t _ _ = return $ Types.SqlString ("unknown type " ++ show t)
+nonNullCellValue #{const MYSQL_TYPE_DATE} p _ _ = do
+  (MYSQL_TIME y m d _ _ _) <- peek $ castPtr p
+  let date = fromGregorian (fromIntegral y) (fromIntegral m) (fromIntegral d)
+  return $ Types.SqlLocalDate date
+
+nonNullCellValue #{const MYSQL_TYPE_NEWDATE} p _ _ = do
+  (MYSQL_TIME y m d _ _ _) <- peek $ castPtr p
+  let date = fromGregorian (fromIntegral y) (fromIntegral m) (fromIntegral d)
+  return $ Types.SqlLocalDate date
+
+nonNullCellValue t _ _ _ = return $ Types.SqlString ("unknown type " ++ show t)
 
 -- Cough up the column metadata for a field that's returned from a
 -- query.
   Types.finish stmt
   return $ map (fromSql . head) rows
       where fromSql :: Types.SqlValue -> String
-            fromSql (Types.SqlString s) = s
-            fromSql _                   = error "SHOW TABLES returned a table whose name wasn't a string"
+            fromSql (Types.SqlByteString s) = toString s
+            fromSql _                       = error "SHOW TABLES returned a table whose name wasn't a string"
 
 -- Describe a single table in the database by issuing a "DESCRIBE"
 -- statement and parsing the results.  (XXX this is sloppy right now;
   Types.finish stmt
   return $ map fromRow rows
       where fromRow :: [Types.SqlValue] -> (String, ColTypes.SqlColDesc)
-            fromRow ((Types.SqlString colname)
-                     :(Types.SqlString coltype)
-                     :(Types.SqlString nullAllowed):_) =
-                let sqlTypeId = typeIdOfString coltype
+            fromRow ((Types.SqlByteString colname)
+                     :(Types.SqlByteString coltype)
+                     :(Types.SqlByteString nullAllowed):_) =
+                let sqlTypeId = typeIdOfString $ toString coltype
                     -- XXX parse the column width and decimals, too!
-                    nullable = Just $ nullAllowed == "YES"
-                in (colname, ColTypes.SqlColDesc sqlTypeId Nothing Nothing Nothing nullable)
+                    nullable = Just $ toString nullAllowed == "YES"
+                in (toString colname, ColTypes.SqlColDesc sqlTypeId Nothing Nothing Nothing nullable)
 
             fromRow _ = error "DESCRIBE failed"
 
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.