Source

text / Data / Text / Encoding / Fusion.hs

Diff from to

File Data/Text/Encoding/Fusion.hs

-{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE BangPatterns, Rank2Types #-}
 
 -- |
 -- Module      : Data.Text.Encoding.Fusion
 import Data.ByteString as B
 import Data.ByteString.Internal (ByteString(..), mallocByteString, memcpy)
 import Data.Text.Fusion (Step(..), Stream(..))
+import Data.Text.Encoding.Error
 import Data.Text.Encoding.Fusion.Common
 import Data.Text.UnsafeChar (unsafeChr, unsafeChr8, unsafeChr32)
 import Data.Text.UnsafeShift (shiftL)
 
 -- | /O(n)/ Convert a 'ByteString' into a 'Stream Char', using UTF-8
 -- encoding.
-streamUtf8 :: ByteString -> Stream Char
-streamUtf8 bs = Stream next 0 l
+streamUtf8 :: OnDecodeError -> ByteString -> Stream Char
+streamUtf8 onErr bs = Stream next 0 l
     where
       l = B.length bs
       {-# INLINE next #-}
           | i+1 < l && U8.validate2 x1 x2 = Yield (U8.chr2 x1 x2) (i+2)
           | i+2 < l && U8.validate3 x1 x2 x3 = Yield (U8.chr3 x1 x2 x3) (i+3)
           | i+3 < l && U8.validate4 x1 x2 x3 x4 = Yield (U8.chr4 x1 x2 x3 x4) (i+4)
-          | otherwise = encodingError "UTF-8"
+          | otherwise = decodeError "streamUtf8" "UTF-8" onErr mx (i+1)
           where
+            mx = if i >= l then Nothing else Just x1
             x1 = idx i
             x2 = idx (i + 1)
             x3 = idx (i + 2)
 
 -- | /O(n)/ Convert a 'ByteString' into a 'Stream Char', using little
 -- endian UTF-16 encoding.
-streamUtf16LE :: ByteString -> Stream Char
-streamUtf16LE bs = Stream next 0 l
+streamUtf16LE :: OnDecodeError -> ByteString -> Stream Char
+streamUtf16LE onErr bs = Stream next 0 l
     where
       l = B.length bs
       {-# INLINE next #-}
           | i >= l                         = Done
           | i+1 < l && U16.validate1 x1    = Yield (unsafeChr x1) (i+2)
           | i+3 < l && U16.validate2 x1 x2 = Yield (U16.chr2 x1 x2) (i+4)
-          | otherwise = encodingError "UTF-16LE"
+          | otherwise = decodeError "streamUtf16LE" "UTF-16LE" onErr Nothing (i+1)
           where
             x1    = idx i       + (idx (i + 1) `shiftL` 8)
             x2    = idx (i + 2) + (idx (i + 3) `shiftL` 8)
 
 -- | /O(n)/ Convert a 'ByteString' into a 'Stream Char', using big
 -- endian UTF-16 encoding.
-streamUtf16BE :: ByteString -> Stream Char
-streamUtf16BE bs = Stream next 0 l
+streamUtf16BE :: OnDecodeError -> ByteString -> Stream Char
+streamUtf16BE onErr bs = Stream next 0 l
     where
       l = B.length bs
       {-# INLINE next #-}
           | i >= l                         = Done
           | i+1 < l && U16.validate1 x1    = Yield (unsafeChr x1) (i+2)
           | i+3 < l && U16.validate2 x1 x2 = Yield (U16.chr2 x1 x2) (i+4)
-          | otherwise = encodingError "UTF16-BE"
+          | otherwise = decodeError "streamUtf16BE" "UTF-16BE" onErr Nothing (i+1)
           where
             x1    = (idx i `shiftL` 8)       + idx (i + 1)
             x2    = (idx (i + 2) `shiftL` 8) + idx (i + 3)
 
 -- | /O(n)/ Convert a 'ByteString' into a 'Stream Char', using big
 -- endian UTF-32 encoding.
-streamUtf32BE :: ByteString -> Stream Char
-streamUtf32BE bs = Stream next 0 l
+streamUtf32BE :: OnDecodeError -> ByteString -> Stream Char
+streamUtf32BE onErr bs = Stream next 0 l
     where
       l = B.length bs
       {-# INLINE next #-}
       next i
           | i >= l                    = Done
           | i+3 < l && U32.validate x = Yield (unsafeChr32 x) (i+4)
-          | otherwise                 = encodingError "UTF-32BE"
+          | otherwise = decodeError "streamUtf32BE" "UTF-32BE" onErr Nothing (i+1)
           where
             x     = shiftL x1 24 + shiftL x2 16 + shiftL x3 8 + x4
             x1    = idx i
 
 -- | /O(n)/ Convert a 'ByteString' into a 'Stream Char', using little
 -- endian UTF-32 encoding.
-streamUtf32LE :: ByteString -> Stream Char
-streamUtf32LE bs = Stream next 0 l
+streamUtf32LE :: OnDecodeError -> ByteString -> Stream Char
+streamUtf32LE onErr bs = Stream next 0 l
     where
       l = B.length bs
       {-# INLINE next #-}
       next i
           | i >= l                    = Done
           | i+3 < l && U32.validate x = Yield (unsafeChr32 x) (i+4)
-          | otherwise                 = encodingError "UTF-32LE"
+          | otherwise = decodeError "streamUtf32LE" "UTF-32LE" onErr Nothing (i+1)
           where
             x     = shiftL x4 24 + shiftL x3 16 + shiftL x2 8 + x1
             x1    = idx i
                   memcpy dest' src' (fromIntegral srcLen)
           return dest
 
-encodingError :: String -> a
-encodingError encoding =
-    error $ "Data.Text.Encoding.Fusion: Bad " ++ encoding ++ " stream"
+decodeError :: forall s. String -> String -> OnDecodeError -> Maybe Word8
+            -> s -> Step s Char
+decodeError func kind onErr mb i =
+    case onErr desc mb of
+      Nothing -> Skip i
+      Just c  -> Yield c i
+    where desc = "Data.Text.Encoding.Fusion." ++ func ++ ": Invalid " ++
+                 kind ++ " stream"