Commits

Bryan O'Sullivan committed bec0b08

Simplify and centralize buffer overflow handling.

Comments (0)

Files changed (4)

Data/Text/ICU/Char.hsc

 import Data.Word (Word8)
 import Foreign.C.String (CString, peekCStringLen, withCString)
 import Foreign.C.Types (CInt)
-import Foreign.Marshal.Alloc (allocaBytes)
 import Foreign.Ptr (Ptr)
 import System.IO.Unsafe (unsafePerformIO)
 
 charName' choice c = fillString $ u_charName (fromIntegral (ord c)) choice
 
 fillString :: (CString -> Int32 -> Ptr UErrorCode -> IO Int32) -> String
-fillString act = unsafePerformIO $ loop 128
- where
-  loop !n = do
-    ret <- allocaBytes n $ \ptr -> do
-             ret <- handleOverflowError $ act ptr (fromIntegral n)
-             case ret of
-              Left overflow -> return (Left overflow)
-              Right r       -> Right `fmap` peekCStringLen (ptr,fromIntegral r)
-    either (loop . fromIntegral) return ret
+fillString act = unsafePerformIO $
+                 handleOverflowError 128 act (curry peekCStringLen)
 
 type UBlockCode = CInt
 type UCharDirection = CInt

Data/Text/ICU/Error/Internal.hsc

     ) where
 
 import Control.Exception (Exception, throwIO)
+import Data.Function (fix)
 import Foreign.Ptr (Ptr)
 import Foreign.Marshal.Alloc (alloca)
 import Foreign.Marshal.Utils (with)
+import Foreign.Marshal.Array (allocaArray)
 import Data.Int (Int32)
 import Data.Typeable (Typeable)
 import Foreign.C.String (CString, peekCString)
                        throwOnError =<< peek errPtr
                        return ret
 
-handleOverflowError :: (Ptr UErrorCode -> IO a) -> IO (Either a a)
-{-# INLINE handleOverflowError #-}
-handleOverflowError action =
-    with 0 $ \uerrPtr -> do
-      ret <- action uerrPtr
+-- | Deal with ICU functions that report a buffer overflow error if we
+-- give them an insufficiently large buffer.  Our first call will
+-- report a buffer overflow, in which case we allocate a correctly
+-- sized buffer and try again.
+handleOverflowError :: (Storable a) =>
+                       Int
+                    -- ^ Initial guess at buffer size.
+                    -> (Ptr a -> Int32 -> Ptr UErrorCode -> IO Int32)
+                    -- ^ Function that retrieves data.
+                    -> (Ptr a -> Int -> IO b)
+                    -- ^ Function that fills destination buffer if no
+                    -- overflow occurred.
+                    -> IO b
+handleOverflowError guess fill retrieve =
+  alloca $ \uerrPtr -> flip fix guess $ \loop n ->
+    (either (loop . fromIntegral) return =<<) . allocaArray n $ \ptr -> do
+      poke uerrPtr 0
+      ret <- fill ptr (fromIntegral n) uerrPtr
       err <- peek uerrPtr
-      if err > 0
-        then if err == #const U_BUFFER_OVERFLOW_ERROR
-             then return (Left ret)
-             else throwIO (ICUError err)
-        else return (Right ret)
+      case undefined of
+        _| err == (#const U_BUFFER_OVERFLOW_ERROR)
+                     -> return (Left ret)
+         | err > 0   -> throwIO (ICUError err)
+         | otherwise -> Right `fmap` retrieve ptr (fromIntegral ret)
 
 handleParseError :: (ICUError -> Bool)
                  -> (Ptr UParseError -> Ptr UErrorCode -> IO a) -> IO a

Data/Text/ICU/Normalize.hsc

 #include <unicode/uchar.h>
 #include <unicode/unorm.h>
 
-import Control.Exception (throwIO)
-import Control.Monad (when)
 import Data.Text (Text)
 import Data.Text.Foreign (fromPtr, useAsPtr)
-import Data.Text.ICU.Error (u_BUFFER_OVERFLOW_ERROR)
-import Data.Text.ICU.Error.Internal (UErrorCode, isFailure, handleError, withError)
+import Data.Text.ICU.Error.Internal (UErrorCode, handleError, handleOverflowError)
 import Data.Text.ICU.Internal (UBool, UChar, asBool, asOrdering)
 import Data.Text.ICU.Normalize.Internal (UNormalizationCheckResult, toNCR)
 import Data.Typeable (Typeable)
 import Data.Int (Int32)
 import Data.Word (Word32)
 import Foreign.C.Types (CInt)
-import Foreign.Marshal.Array (allocaArray)
-import Foreign.Ptr (Ptr)
+import Foreign.Ptr (Ptr, castPtr)
 import System.IO.Unsafe (unsafePerformIO)
 import Prelude hiding (compare)
 import Data.List (foldl')
 normalize mode t = unsafePerformIO . useAsPtr t $ \sptr slen ->
   let slen' = fromIntegral slen
       mode' = toNM mode
-      loop dlen =
-        (either loop return =<<) .
-        allocaArray dlen $ \dptr -> do
-          (err, newLen) <- withError $
-              unorm_normalize sptr slen' mode' 0 dptr (fromIntegral dlen)
-          when (isFailure err && err /= u_BUFFER_OVERFLOW_ERROR) $
-            throwIO err
-          let newLen' = fromIntegral newLen
-          if newLen' > dlen
-            then return (Left newLen')
-            else Right `fmap` fromPtr dptr (fromIntegral newLen')
-  in loop (fromIntegral slen)
+  in handleOverflowError (fromIntegral slen)
+     (\dptr dlen -> unorm_normalize sptr slen' mode' 0 dptr (fromIntegral dlen))
+     (\dptr dlen -> fromPtr (castPtr dptr) (fromIntegral dlen))
     
       
 -- | Perform an efficient check on a string, to quickly determine if

Data/Text/ICU/Text.hs

 import Data.Word (Word32)
 import Foreign.C.String (CString)
 import Foreign.Marshal.Array (allocaArray)
-import Foreign.Ptr (Ptr)
+import Foreign.Ptr (Ptr, castPtr)
 import System.IO.Unsafe (unsafePerformIO)
 
 -- $case
 caseMap :: CaseMapper -> LocaleName -> Text -> Text
 caseMap mapFn loc s = unsafePerformIO .
   withLocaleName loc $ \locale ->
-    useAsPtr s $ \sptr slen -> do
-      let go len = do
-            ret <- allocaArray len $ \dptr -> do
-                  ret <- handleOverflowError $
-                         mapFn dptr (fromIntegral len) sptr
-                                    (fromIntegral slen) locale
-                  case ret of
-                    Left overflow -> return (Left overflow)
-                    Right n       -> Right `fmap` fromPtr dptr (fromIntegral n)
-            either (go . fromIntegral) return ret
-      go (fromIntegral slen)
+    useAsPtr s $ \sptr slen ->
+      handleOverflowError (fromIntegral slen)
+      (\dptr dlen -> mapFn dptr dlen sptr (fromIntegral slen) locale)
+      (\dptr dlen -> fromPtr (castPtr dptr) (fromIntegral dlen))
 
 -- | Lowercase the characters in a string.
 --
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.