Commits

Bryan O'Sullivan  committed fe0f880

Handle ?/value mismatch properly.

  • Participants
  • Parent commits 5eaa875

Comments (0)

Files changed (1)

File Database/MySQL/Simple.hs

+{-# LANGUAGE DeriveDataTypeable #-}
+
 module Database.MySQL.Simple
     (
-      execute
+      FormatError(fmtMessage, fmtQuery, fmtParams)
+    , Only(..)
+    , execute
     , query
     , query_
     , formatQuery
 import Blaze.ByteString.Builder (fromByteString, toByteString)
 import Control.Applicative ((<$>), pure)
 import Control.DeepSeq (NFData(..))
+import Control.Exception (Exception, throw)
 import Control.Monad.Fix (fix)
 import Data.ByteString (ByteString)
 import Data.Int (Int64)
-import Data.Monoid (mappend, mempty)
+import Data.Monoid (mappend)
+import Data.Typeable (Typeable)
 import Database.MySQL.Base (Connection)
 import Database.MySQL.Simple.Param (Action(..), inQuotes)
 import Database.MySQL.Simple.QueryParams (QueryParams(..))
 import Database.MySQL.Simple.QueryResults (QueryResults(..))
-import Database.MySQL.Simple.Types (Query(..))
+import Database.MySQL.Simple.Types (Only(..), Query(..))
 import qualified Data.ByteString.Char8 as B
 import qualified Database.MySQL.Base as Base
 
+data FormatError = FormatError {
+      fmtMessage :: String
+    , fmtQuery :: Query
+    , fmtParams :: [ByteString]
+    } deriving (Eq, Show, Typeable)
+
+instance Exception FormatError
+
 formatQuery :: QueryParams q => Connection -> Query -> q -> IO ByteString
-formatQuery conn (Query template) qs
-    | '?' `B.notElem` template = return template
-    | otherwise =
-        toByteString . zipParams (split template) <$> mapM sub (renderParams qs)
-  where sub (Plain b)  = pure b
+formatQuery conn q@(Query template) qs
+    | null xs && '?' `B.notElem` template = return template
+    | otherwise = toByteString . zipParams (split template) <$> mapM sub xs
+  where xs = renderParams qs
+        sub (Plain b)  = pure b
         sub (Escape s) = (inQuotes . fromByteString) <$> Base.escape conn s
-        split q = fromByteString h : if B.null t then [] else split (B.tail t)
-            where (h,t) = B.break (=='?') q
+        split s = fromByteString h : if B.null t then [] else split (B.tail t)
+            where (h,t) = B.break (=='?') s
         zipParams (t:ts) (p:ps) = t `mappend` p `mappend` zipParams ts ps
-        zipParams [] []         = mempty
-        zipParams [] _ = fmtError "more parameters than '?' characters"
-        zipParams _ [] = fmtError "more '?' characters than parameters"
+        zipParams [t] []        = t
+        zipParams _ _ = fmtError (show (B.count '?' template) ++
+                                  " '?' characters, but " ++
+                                  show (length xs) ++ " parameters") q xs
 
 execute :: (QueryParams q) => Connection -> Query -> q -> IO Int64
 execute conn template qs = do
           _  -> let c = convertResults fs row
                 in rnf c `seq` loop (c:acc)
 
-fmtError :: String -> a
-fmtError msg = error $ "Database.MySQL.formatQuery: " ++ msg
+fmtError :: String -> Query -> [Action] -> a
+fmtError msg q xs = throw FormatError {
+                      fmtMessage = msg
+                    , fmtQuery = q
+                    , fmtParams = map twiddle xs
+                    }
+  where twiddle (Plain b)  = toByteString b
+        twiddle (Escape s) = s