Commits

Bryan O'Sullivan  committed d2781a9

Implement decompression of lazy bytestrings.

  • Participants
  • Parent commits c7d9665

Comments (0)

Files changed (7)

File Codec/Compression/Snappy.hs

     , decompress
     ) where
 
-import Codec.Compression.Snappy.Internal (maxCompressedLength)
-import Control.Monad (when)
+import Codec.Compression.Snappy.Internal (check, maxCompressedLength)
 import Data.ByteString.Internal (ByteString(..), mallocByteString)
-import Data.Word (Word8)
+import Data.Word (Word8, Word32)
 import Foreign.C.Types (CInt, CSize)
 import Foreign.ForeignPtr (withForeignPtr)
 import Foreign.Marshal.Alloc (alloca)
 
 -- | Compress data into the Snappy format.
 compress :: ByteString -> ByteString
-compress bs@(PS sfp off len) = unsafePerformIO $ do
+compress (PS sfp off len) = unsafePerformIO $ do
   let dlen0 = maxCompressedLength len
   dfp <- mallocByteString dlen0
   withForeignPtr sfp $ \sptr ->
   withForeignPtr sfp $ \sptr0 -> do
     let sptr = sptr0 `plusPtr` off
         len = fromIntegral slen
-    let check ok = when (ok == 0) $
-                   fail "Codec.Compression.Snappy.decompress: corrupt input"
     alloca $ \dlenPtr -> do
-      check =<< c_GetUncompressedLength sptr len dlenPtr
+      check "decompress" $ c_GetUncompressedLength sptr len dlenPtr
       dlen <- fromIntegral `fmap` peek dlenPtr
-      dfp <- mallocByteString dlen
-      withForeignPtr dfp $ \dptr -> do
-        check =<< c_RawUncompress sptr len dptr
-        return (PS dfp 0 dlen)
+      if dlen == 0
+        then return B.empty
+        else do
+          dfp <- mallocByteString dlen
+          withForeignPtr dfp $ \dptr -> do
+            check "decompress" $ c_RawUncompress sptr len dptr
+            return (PS dfp 0 dlen)
 
 foreign import ccall unsafe "hs_snappy.h _hsnappy_RawCompress"
     c_RawCompress :: Ptr a -> CSize -> Ptr Word8 -> Ptr CSize -> IO ()
 
 foreign import ccall unsafe "hs_snappy.h _hsnappy_GetUncompressedLength"
-    c_GetUncompressedLength :: Ptr a -> CSize -> Ptr CSize -> IO CInt
+    c_GetUncompressedLength :: Ptr a -> CSize -> Ptr Word32 -> IO CInt
 
 foreign import ccall unsafe "hs_snappy.h _hsnappy_RawUncompress"
     c_RawUncompress :: Ptr a -> CSize -> Ptr Word8 -> IO CInt

File Codec/Compression/Snappy/Internal.hs

 
 module Codec.Compression.Snappy.Internal
     (
-      maxCompressedLength
+      check
+    , maxCompressedLength
     ) where
 
+import Control.Monad (when)
 import Foreign.C.Types (CSize)
 
 maxCompressedLength :: Int -> Int
 maxCompressedLength = fromIntegral . c_MaxCompressedLength . fromIntegral
 {-# INLINE maxCompressedLength #-}
 
+check :: (Integral a) => String -> IO a -> IO ()
+check func act = do
+  ok <- act
+  when (ok == 0) . fail $ "Codec.Compression.Snappy." ++ func ++
+                          ": corrupt input "
+{-# INLINE check #-}
+
 foreign import ccall unsafe "hs_snappy.h _hsnappy_MaxCompressedLength"
     c_MaxCompressedLength :: CSize -> CSize

File Codec/Compression/Snappy/Lazy.hsc

 
 #include "hs_snappy.h"
 
-import Codec.Compression.Snappy.Internal (maxCompressedLength)
+import Codec.Compression.Snappy.Internal (check, maxCompressedLength)
+import Control.Exception (bracket)
 import Data.ByteString.Internal hiding (ByteString)
 import Data.ByteString.Lazy.Internal (ByteString(..))
-import Data.Word (Word8)
+import Data.Word (Word8, Word32)
 import Foreign.C.Types (CInt, CSize)
 import Foreign.ForeignPtr (touchForeignPtr, withForeignPtr)
+import Foreign.Marshal.Alloc (alloca)
 import Foreign.Marshal.Array (withArray)
 import Foreign.Marshal.Utils (with)
 import Foreign.Ptr (Ptr, plusPtr)
 import Foreign.Storable (Storable(..))
 import System.IO.Unsafe (unsafePerformIO)
-import qualified Codec.Compression.Snappy as S
 import qualified Data.ByteString as B
 import qualified Data.ByteString.Lazy as L
 
 newtype BS = BS B.ByteString
 
+data BSSource
+
 instance Storable BS where
     sizeOf _    = (#size struct BS)
     alignment _ = alignment (undefined :: Ptr CInt)
     poke ptr (BS (PS fp off len)) = withForeignPtr fp $ \p -> do
       (#poke struct BS, ptr) ptr (p `plusPtr` off)
+      (#poke struct BS, off) ptr (0::CSize)
       (#poke struct BS, len) ptr len
     {-# INLINE poke #-}
 
 -- | Compress data into the Snappy format.
 compress :: ByteString -> ByteString
-compress bs = unsafePerformIO $ do
-  let len = fromIntegral (L.length bs)
+compress bs = unsafePerformIO . withChunks bs $ \chunkPtr numChunks len -> do
   let dlen0 = maxCompressedLength len
   dfp <- mallocByteString dlen0
   withForeignPtr dfp $ \dptr -> do
-    let chunks = L.toChunks bs
-    withArray (map BS chunks) $ \chunkPtr ->
-      with (fromIntegral dlen0) $ \dlenPtr -> do
-        c_CompressChunks chunkPtr (fromIntegral (length chunks))
-                         (fromIntegral len) dptr dlenPtr
-        foldr (\(PS fp _ _) _ -> touchForeignPtr fp) (return ()) chunks
-        dlen <- fromIntegral `fmap` peek dlenPtr
-        if dlen == 0
-          then return Empty
-          else return (Chunk (PS dfp 0 dlen) Empty)
+    with (fromIntegral dlen0) $ \dlenPtr -> do
+      c_CompressChunks chunkPtr (fromIntegral numChunks)
+                       (fromIntegral len) dptr dlenPtr
+      dlen <- fromIntegral `fmap` peek dlenPtr
+      if dlen == 0
+        then return Empty
+        else return (Chunk (PS dfp 0 dlen) Empty)
 
 -- | Decompress data in the Snappy format.
 --
 -- If the input is not compressed or is corrupt, an exception will be
 -- thrown.
 decompress :: ByteString -> ByteString
-decompress = L.fromChunks . (:[]) . S.decompress . B.concat . L.toChunks
+decompress bs = unsafePerformIO . withChunks bs $ \chunkPtr numChunks len ->
+  bracket (c_NewSource chunkPtr (fromIntegral numChunks) (fromIntegral len))
+          c_DeleteSource $ \srcPtr -> do
+    alloca $ \dlenPtr -> do
+      check "Lazy.decompress" $ c_GetUncompressedLengthChunks srcPtr dlenPtr
+      dlen <- fromIntegral `fmap` peek dlenPtr
+      if dlen == 0
+        then return L.empty
+        else do
+          dfp <- mallocByteString dlen
+          withForeignPtr dfp $ \dptr -> do
+            check "Lazy.decompress" $ c_UncompressChunks srcPtr dptr
+            return (Chunk (PS dfp 0 dlen) Empty)
+
+withChunks :: ByteString -> (Ptr BS -> Int -> Int -> IO a) -> IO a
+withChunks bs act = do
+  let len = fromIntegral (L.length bs)
+  let chunks = L.toChunks bs
+  r <- withArray (map BS chunks) $ \chunkPtr ->
+       act chunkPtr (length chunks) len
+  foldr (\(PS fp _ _) _ -> touchForeignPtr fp) (return ()) chunks
+  return r
 
 foreign import ccall unsafe "hs_snappy.h _hsnappy_CompressChunks"
     c_CompressChunks :: Ptr BS -> CSize -> CSize -> Ptr Word8 -> Ptr CSize
                      -> IO ()
+
+foreign import ccall unsafe "hs_snappy.h _hsnappy_NewSource"
+    c_NewSource :: Ptr BS -> CSize -> CSize -> IO (Ptr BSSource)
+
+foreign import ccall unsafe "hs_snappy.h _hsnappy_DeleteSource"
+    c_DeleteSource :: Ptr BSSource -> IO ()
+
+foreign import ccall unsafe "hs_snappy.h _hsnappy_UncompressChunks"
+    c_UncompressChunks :: Ptr BSSource -> Ptr Word8 -> IO Int
+
+foreign import ccall unsafe "hs_snappy.h _hsnappy_GetUncompressedLengthChunks"
+    c_GetUncompressedLengthChunks :: Ptr BSSource -> Ptr Word32 -> IO Int

File cbits/hs_snappy.cpp

   return RawUncompress(compressed, compressed_length, uncompressed);
 }
 
-class BSSource : public Source 
+class BSSource : public Source
 {
 public:
-  BSSource(BS *chunks, size_t nchunks, size_t left)
-    : chunks_(chunks), nchunks_(nchunks), cur_(chunks), left_(left) { }
-  
+  BSSource(BS *chunks, size_t nchunks, size_t size)
+    : chunks_(chunks), nchunks_(nchunks), size_(size), cur_(chunks),
+      left_(size) { }
+
   size_t Available() const { return left_; }
-  
+
   const char *Peek(size_t *len) {
-    *len = cur_->len;
-    return cur_->ptr;
+    if (left_ > 0) {
+      *len = cur_->len - cur_->off;
+      return cur_->ptr + cur_->off;
+    } else {
+      *len = 0;
+      return NULL;
+    }
   }
 
   void Skip(size_t n) {
-    left_ -= n;
-    while (n >= cur_->len) {
-      n -= cur_->len;
-      cur_++;
+    if (n > 0) {
+      left_ -= n;
+      cur_->off += n;
+      if (cur_->off == cur_->len)
+	cur_++;
     }
-    if (n > 0) {
-      cur_->len -= n;
-      cur_->ptr += n;
-    }
+  }
+
+  void Rewind() {
+    left_ = size_;
+    cur_ = chunks_;
+    for (size_t i = 0; i < nchunks_ && chunks_[i].off > 0; i++)
+      chunks_[i].off = 0;
   }
 
 private:
   BS *chunks_;
-  const int nchunks_;
+  const size_t nchunks_;
+  const size_t size_;
   BS *cur_;
   size_t left_;
 };
-  
+
 void _hsnappy_CompressChunks(BS *chunks, size_t nchunks, size_t length,
 			     char *compressed, size_t *compressed_length)
 {
 
   *compressed_length = writer.CurrentDestination() - compressed;
 }
+
+BSSource *_hsnappy_NewSource(BS *chunks, size_t nchunks, size_t length)
+{
+  return new BSSource(chunks, nchunks, length);
+}
+
+void _hsnappy_DeleteSource(BSSource *src)
+{
+  delete src;
+}
+
+int _hsnappy_UncompressChunks(BSSource *reader, char *uncompressed)
+{
+  return RawUncompress(reader, uncompressed);
+}
+
+int _hsnappy_GetUncompressedLengthChunks(BSSource *reader, uint32_t *result)
+{
+  int n = GetUncompressedLength(reader, result);
+  reader->Rewind();
+  return n;
+}

File include/hs_snappy.h

 #define _hs_snappy_h
 
 #include <stddef.h>
+#include <stdint.h>
 
 #ifdef __cplusplus
-extern "C" 
+extern "C"
 {
 #endif
 
 struct BS {
   const char *ptr;
+  size_t off;
   size_t len;
 };
-    
+
 size_t _hsnappy_MaxCompressedLength(size_t);
 
 void _hsnappy_RawCompress(const char *input, size_t input_length,
 int _hsnappy_RawUncompress(const char *compressed, size_t compressed_length,
 			   char *uncompressed);
 
+struct BS;
+
 void _hsnappy_CompressChunks(struct BS *chunks, size_t count,
 			     size_t length, char *compressed,
 			     size_t *compressed_length);
 
+struct BSSource;
+
+struct BSSource *_hsnappy_NewSource(struct BS *chunks, size_t nchunks,
+				    size_t length);
+
+void _hsnappy_DeleteSource(struct BSSource *reader);
+
+int _hsnappy_UncompressChunks(struct BSSource *reader, char *uncompressed);
+
+int _hsnappy_GetUncompressedLengthChunks(struct BSSource *reader,
+					 uint32_t *result);
+
 #ifdef __cplusplus
 }
 #endif

File snappy.cabal

 
 library
   c-sources:       cbits/hs_snappy.cpp
-  cc-options: -g -O0
   include-dirs:    include
   extra-libraries: snappy stdc++
 
-  build-depends:     base < 5, bytestring
+  cc-options:      -Wall
+  ghc-options:     -Wall
+
+  build-depends:   base < 5, bytestring
   if impl(ghc >= 6.10)
-    build-depends:   base >= 4
+    build-depends: base >= 4
 
   exposed-modules:
     Codec.Compression.Snappy

File tests/Properties.hs

+{-# LANGUAGE FlexibleInstances #-}
+
+import Control.Applicative
 import qualified Codec.Compression.Snappy as B
 import qualified Codec.Compression.Snappy.Lazy as L
 import Test.Framework (defaultMain, testGroup)
 import qualified Data.ByteString as B
 import qualified Data.ByteString.Lazy as L
 
-s_roundtrip s = B.decompress (B.compress bs) == bs
-  where bs = B.pack s
+instance Arbitrary B.ByteString where
+    arbitrary = B.pack <$> arbitrary
 
-l_roundtrip s = L.decompress (L.compress bs) == bs
-  where bs = L.pack s
+instance Arbitrary L.ByteString where
+    arbitrary = rechunk <$> arbitrary <*> arbitrary
+
+s_roundtrip bs = B.decompress (B.compress bs) == bs
+
+newtype Compressed a = Compressed { compressed :: a }
+    deriving (Eq, Ord)
+
+instance Show a => Show (Compressed a)
+    where show (Compressed a) = "Compressed " ++ show a
+
+instance Arbitrary (Compressed B.ByteString) where
+    arbitrary = (Compressed . B.compress) <$> arbitrary
+
+compress_eq n bs = L.fromChunks [B.compress bs] == L.compress (rechunk n bs)
+decompress_eq n bs0 =
+    L.fromChunks [B.decompress bs] == L.decompress (rechunk n bs)
+  where bs = B.compress bs0
+
+rechunk :: Int -> B.ByteString -> L.ByteString
+rechunk n = L.fromChunks . go
+  where go bs | B.null bs = []
+              | otherwise = case B.splitAt ((n `mod` 63) + 1) bs of
+                              (x,y) -> x : go y
+
+t_rechunk n bs = L.fromChunks [bs] == rechunk n bs
+
+l_roundtrip bs = L.decompress (L.compress bs) == bs
 
 main = defaultMain tests
 
 tests = [
     testProperty "s_roundtrip" s_roundtrip
+  , testProperty "t_rechunk" t_rechunk
+  , testProperty "compress_eq" compress_eq
+  , testProperty "decompress_eq" decompress_eq
   , testProperty "l_roundtrip" l_roundtrip
   ]