Commits

Bryan O'Sullivan  committed 57426e5

Add support for compression of lazy bytestrings.

  • Participants
  • Parent commits 9d0e1f9

Comments (0)

Files changed (8)

File Codec/Compression/Snappy.hs

     , decompress
     ) where
 
-import Control.Monad (unless)
+import Codec.Compression.Snappy.Internal (maxCompressedLength)
+import Control.Monad (when)
 import Data.ByteString.Internal (ByteString(..), mallocByteString)
 import Data.Word (Word8)
-import Foreign.C.Types (CSize)
+import Foreign.C.Types (CInt, CSize)
 import Foreign.ForeignPtr (withForeignPtr)
 import Foreign.Marshal.Alloc (alloca)
 import Foreign.Marshal.Utils (with)
 -- | Compress data into the Snappy format.
 compress :: ByteString -> ByteString
 compress bs@(PS sfp off len) = unsafePerformIO $ do
-  let dlen0 = fromIntegral . c_MaxCompressedLength . fromIntegral $ len
+  let dlen0 = maxCompressedLength len
   dfp <- mallocByteString dlen0
   withForeignPtr sfp $ \sptr ->
     withForeignPtr dfp $ \dptr ->
   withForeignPtr sfp $ \sptr0 -> do
     let sptr = sptr0 `plusPtr` off
         len = fromIntegral slen
+    let check ok = when (ok == 0) $
+                   fail "Codec.Compression.Snappy.decompress: corrupt input"
     alloca $ \dlenPtr -> do
-      ok0 <- c_GetUncompressedLength sptr len dlenPtr
-      unless ok0 $ error "Codec.Compression.Snappy.decompress: corrupt input"
+      check =<< c_GetUncompressedLength sptr len dlenPtr
       dlen <- fromIntegral `fmap` peek dlenPtr
       dfp <- mallocByteString dlen
       withForeignPtr dfp $ \dptr -> do
-        ok1 <- c_RawUncompress sptr len dptr
-        unless ok1 $ error "Codec.Compression.Snappy.decompress: corrupt input"
+        check =<< c_RawUncompress sptr len dptr
         return (PS dfp 0 dlen)
 
-foreign import ccall unsafe "hs_snappy.h _hsnappy_MaxCompressedLength"
-    c_MaxCompressedLength :: CSize -> CSize
-
 foreign import ccall unsafe "hs_snappy.h _hsnappy_RawCompress"
     c_RawCompress :: Ptr a -> CSize -> Ptr Word8 -> Ptr CSize -> IO ()
 
 foreign import ccall unsafe "hs_snappy.h _hsnappy_GetUncompressedLength"
-    c_GetUncompressedLength :: Ptr a -> CSize -> Ptr CSize -> IO Bool
+    c_GetUncompressedLength :: Ptr a -> CSize -> Ptr CSize -> IO CInt
 
 foreign import ccall unsafe "hs_snappy.h _hsnappy_RawUncompress"
-    c_RawUncompress :: Ptr a -> CSize -> Ptr Word8 -> IO Bool
+    c_RawUncompress :: Ptr a -> CSize -> Ptr Word8 -> IO CInt

File Codec/Compression/Snappy/Internal.hs

+{-# LANGUAGE ForeignFunctionInterface #-}
+
+-- |
+-- Module:      Codec.Compression.Snappy
+-- Copyright:   (c) 2011 MailRank, Inc.
+-- License:     Apache
+-- Maintainer:  Bryan O'Sullivan <bos@mailrank.com>
+-- Stability:   experimental
+-- Portability: portable
+--
+-- This module provides fast, pure Haskell bindings to Google's
+-- Snappy compression and decompression library:
+-- <http://code.google.com/p/snappy/>
+--
+-- These functions operate on strict bytestrings, and thus use as much
+-- memory as both the entire compressed and uncompressed data.
+
+module Codec.Compression.Snappy.Internal
+    (
+      maxCompressedLength
+    ) where
+
+import Foreign.C.Types (CSize)
+
+maxCompressedLength :: Int -> Int
+maxCompressedLength = fromIntegral . c_MaxCompressedLength . fromIntegral
+{-# INLINE maxCompressedLength #-}
+
+foreign import ccall unsafe "hs_snappy.h _hsnappy_MaxCompressedLength"
+    c_MaxCompressedLength :: CSize -> CSize

File Codec/Compression/Snappy/Lazy.hsc

+{-# LANGUAGE BangPatterns, EmptyDataDecls, ForeignFunctionInterface #-}
+
+-- |
+-- Module:      Codec.Compression.Snappy
+-- Copyright:   (c) 2011 MailRank, Inc.
+-- License:     Apache
+-- Maintainer:  Bryan O'Sullivan <bos@mailrank.com>
+-- Stability:   experimental
+-- Portability: portable
+--
+-- This module provides fast, pure compression and decompression
+-- of Snappy data.
+
+module Codec.Compression.Snappy.Lazy
+    (
+      compress
+    , decompress
+    ) where
+
+#include "hs_snappy.h"
+
+import Codec.Compression.Snappy.Internal (maxCompressedLength)
+import qualified Codec.Compression.Snappy as S
+import Data.ByteString.Lazy.Internal
+import qualified Data.ByteString.Lazy as L
+import Data.Word (Word8)
+import Data.ByteString.Internal hiding (ByteString)
+import qualified Data.ByteString as B
+import Foreign.C.Types (CInt, CSize)
+import Foreign.Storable
+import Foreign.Ptr (Ptr, plusPtr)
+import Foreign.ForeignPtr (withForeignPtr)
+import System.IO.Unsafe (unsafePerformIO)
+import Foreign.Marshal.Array
+import Foreign.Marshal.Utils
+
+newtype BS = BS B.ByteString
+
+instance Storable BS where
+    sizeOf _    = (#size struct BS)
+    alignment _ = alignment (undefined :: Ptr CInt)
+    poke ptr (BS (PS fp off len)) = withForeignPtr fp $ \p -> do
+      (#poke struct BS, ptr) ptr (p `plusPtr` off)
+      (#poke struct BS, len) ptr len
+    {-# INLINE poke #-}
+
+-- | Compress data into the Snappy format.
+compress :: ByteString -> ByteString
+compress bs = unsafePerformIO $ do
+  let len = fromIntegral (L.length bs)
+  let dlen0 = maxCompressedLength len
+  dfp <- mallocByteString dlen0
+  withForeignPtr dfp $ \dptr -> do
+    let chunks = L.toChunks bs
+    withArray (map BS chunks) $ \chunkPtr ->
+      with (fromIntegral dlen0) $ \dlenPtr -> do
+        c_CompressChunks chunkPtr (fromIntegral (length chunks))
+                         (fromIntegral len) dptr dlenPtr
+        dlen <- fromIntegral `fmap` peek dlenPtr
+        if dlen == 0
+          then return Empty
+          else return (Chunk (PS dfp 0 dlen) Empty)
+
+-- | Decompress data in the Snappy format.
+--
+-- If the input is not compressed or is corrupt, an exception will be
+-- thrown.
+decompress :: ByteString -> ByteString
+decompress = L.fromChunks . (:[]) . S.decompress . B.concat . L.toChunks
+
+foreign import ccall unsafe "hs_snappy.h _hsnappy_CompressChunks"
+    c_CompressChunks :: Ptr BS -> CSize -> CSize -> Ptr Word8 -> Ptr CSize
+                     -> IO ()

File cbits/hs_snappy.cpp

 #include "hs_snappy.h"
 #include "snappy.h"
+#include "snappy-sinksource.h"
+
+using namespace snappy;
 
 size_t _hsnappy_MaxCompressedLength(size_t n)
 {
-  return snappy::MaxCompressedLength(n);
+  return MaxCompressedLength(n);
 }
 
 void _hsnappy_RawCompress(const char *input, size_t input_length,
 			  char *compressed, size_t *compressed_length)
 {
-  snappy::RawCompress(input, input_length, compressed, compressed_length);
+  RawCompress(input, input_length, compressed, compressed_length);
 }
 
-bool _hsnappy_GetUncompressedLength(const char *compressed,
-				    size_t compressed_length,
-				    size_t *result)
+int _hsnappy_GetUncompressedLength(const char *compressed,
+				   size_t compressed_length,
+				   size_t *result)
 {
-  return snappy::GetUncompressedLength(compressed, compressed_length, result);
+  return GetUncompressedLength(compressed, compressed_length, result);
 }
 
-bool _hsnappy_RawUncompress(const char *compressed, size_t compressed_length,
-			    char *uncompressed)
+int _hsnappy_RawUncompress(const char *compressed, size_t compressed_length,
+			   char *uncompressed)
 {
-  return snappy::RawUncompress(compressed, compressed_length, uncompressed);
+  return RawUncompress(compressed, compressed_length, uncompressed);
 }
+
+class BSSource : public Source 
+{
+public:
+  BSSource(BS *chunks, size_t nchunks, size_t left)
+    : chunks_(chunks), nchunks_(nchunks), cur_(chunks), left_(left) { }
+  
+  size_t Available() const { return left_; }
+  
+  const char *Peek(size_t *len) {
+    *len = cur_->len;
+    return cur_->ptr;
+  }
+
+  void Skip(size_t n) {
+    left_ -= n;
+    while (n >= cur_->len) {
+      n -= cur_->len;
+      cur_++;
+    }
+    if (n > 0) {
+      cur_->len -= n;
+      cur_->ptr += n;
+    }
+  }
+
+private:
+  BS *chunks_;
+  const int nchunks_;
+  BS *cur_;
+  size_t left_;
+};
+  
+void _hsnappy_CompressChunks(BS *chunks, size_t nchunks, size_t length,
+			     char *compressed, size_t *compressed_length)
+{
+  BSSource reader(chunks, nchunks, length);
+  UncheckedByteArraySink writer(compressed);
+
+  Compress(&reader, &writer);
+
+  *compressed_length = writer.CurrentDestination() - compressed;
+}

File include/hs_snappy.h

 {
 #endif
 
+struct BS {
+  const char *ptr;
+  size_t len;
+};
+    
 size_t _hsnappy_MaxCompressedLength(size_t);
 
 void _hsnappy_RawCompress(const char *input, size_t input_length,
 			  char *compressed, size_t *compressed_length);
 
-bool _hsnappy_GetUncompressedLength(const char *compressed,
-				    size_t compressed_length,
-				    size_t *result);
+int _hsnappy_GetUncompressedLength(const char *compressed,
+				   size_t compressed_length,
+				   size_t *result);
 
-bool _hsnappy_RawUncompress(const char *compressed, size_t compressed_length,
-			    char *uncompressed);
+int _hsnappy_RawUncompress(const char *compressed, size_t compressed_length,
+			   char *uncompressed);
+
+void _hsnappy_CompressChunks(struct BS *chunks, size_t count,
+			     size_t length, char *compressed,
+			     size_t *compressed_length);
 
 #ifdef __cplusplus
 }

File snappy.cabal

 
 library
   c-sources:       cbits/hs_snappy.cpp
+  cc-options: -g -O0
   include-dirs:    include
   extra-libraries: snappy stdc++
 
 
   exposed-modules:
     Codec.Compression.Snappy
+    Codec.Compression.Snappy.Lazy
+
+  other-modules:
+    Codec.Compression.Snappy.Internal
 
 source-repository head
   type:     git

File tests/Makefile

 ghc := ghc
+ghcflags := -threaded -O
 
 all: qc snappy
 
 qc: Properties.hs
-	$(ghc) --make -o $@ $^
+	$(ghc) $(ghcflags) --make -o $@ $^
 
 snappy: Snappy.hs
-	$(ghc) -O --make -o $@ $^
+	$(ghc) $(ghcflags) -O --make -o $@ $^
 
 clean:
 	-rm -f qc snappy *.o *.hi

File tests/Properties.hs

-import Codec.Compression.Snappy
+import qualified Codec.Compression.Snappy as B
+import qualified Codec.Compression.Snappy.Lazy as L
 import Test.Framework (defaultMain, testGroup)
 import Test.Framework.Providers.QuickCheck2 (testProperty)
 import Test.QuickCheck (Arbitrary(..))
 import qualified Data.ByteString as B
+import qualified Data.ByteString.Lazy as L
 
-roundtrip s = decompress (compress bs) == bs
+s_roundtrip s = B.decompress (B.compress bs) == bs
   where bs = B.pack s
 
+l_roundtrip s = L.decompress (L.compress bs) == bs
+  where bs = L.pack s
+
 main = defaultMain tests
 
 tests = [
-    testProperty "roundtrip" roundtrip
+    testProperty "s_roundtrip" s_roundtrip
+  , testProperty "l_roundtrip" l_roundtrip
   ]