Commits

Grzegorz Chrupała committed 306ee5f

Workaround for WriterT retaining input, preventing streaming.

  • Parent commits eccf4bf
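
For context, the sketch below (standalone, with made-up names such as labelSentence; not code from this repository) illustrates the problem and the workaround: a WriterT log over ST is only handed back when runWriterT returns, so every labeling passed to tell, together with the input it still references, is held in memory until the sampler finishes. Emitting each labeling eagerly via unsafePerformIO lets the output stream instead.

    import Control.Monad.ST (ST, runST)
    import Control.Monad.Trans.Class (lift)
    import Control.Monad.Writer (WriterT, runWriterT, tell)
    import qualified System.IO.Unsafe as Unsafe

    -- Hypothetical stand-in for one sampling pass labeling a sentence.
    labelSentence :: [Int] -> ST s [Int]
    labelSentence = pure . map (`mod` 10)

    -- Buffered: labelings collected with tell are only returned when
    -- runWriterT finishes, so the whole log (and the input it still
    -- references) is retained until the run completes.
    labelAllBuffered :: [[Int]] -> [[Int]]
    labelAllBuffered sents = snd (runST (runWriterT (mapM_ step sents)))
      where
        step s = do
          ls <- lift (labelSentence s)
          tell [ls]

    -- Streaming workaround: push each labeling out as a side effect the
    -- moment it is computed; nothing accumulates.
    out :: [Int] -> ()
    out ls = Unsafe.unsafePerformIO (print ls)

    labelAllStreaming :: [[Int]] -> ()
    labelAllStreaming sents = runST (mapM_ step sents)
      where
        step s = do
          ls <- labelSentence s
          out ls `seq` pure ()

The streaming variant relies on forcing the result of unsafePerformIO so the write happens at the right moment, which is why the diff marks it with a FIXME as a workaround rather than a final design.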


Files changed (4)

File colada/Colada/WordClass.hs

+
 {-# LANGUAGE 
    OverloadedStrings  
  , FlexibleInstances
 where       
   
 -- Standard libraries  
+import qualified Data.Text.Lazy.IO          as Text
 import qualified Data.Text.Lazy             as Text
+import qualified Data.Text.Lazy.Builder     as Text
+import qualified Data.Text.Lazy.Builder.Int as Text
 import qualified Data.Text.Lazy.Encoding    as Text
 import qualified Data.Vector                as V
 import qualified Data.Vector.Generic        as G
 import Prelude                              hiding ((.), exponent)
 import Control.Category    ((.))
 import Control.Applicative ((<$>))
-
+import qualified System.IO.Unsafe as Unsafe
 -- Third party modules  
 import qualified Control.Monad.Atom  as Atom
 import qualified NLP.CoNLL  as CoNLL
                  
 -- | @learn options xs@ runs the LDA Gibbs sampler for word classes with
 -- @options@ on sentences @xs@, and returns the resulting model
-learn :: Options -> [CoNLL.Sentence] -> (WordClass, [V.Vector LDA.D])
+learn :: Options -> [CoNLL.Sentence] -> WordClass
 learn opts xs = 
   let ((sbs_init, sbs_rest), atomTabD, atomTabW) = 
         Symbols.runSymbols prepare Symbols.empty Symbols.empty
         rest <- prepareData (get batchSize opts)
                                    (get repeats opts) 
                                    (get featIds opts) 
-                                    xs_rest
+                xs_rest
         return (ini, rest)
       best = V.map U.maxIndex
-      sampler :: WriterT [V.Vector LDA.D] (LST.ST s) LDA.Finalized
+      formatLabeling = Text.unlines . V.toList 
+                          . V.map (Text.toLazyText . Text.decimal)
+                          -- FIXME: workaround for WriterT retaining input
+      out ls = Unsafe.unsafePerformIO $ Text.putStrLn . formatLabeling 
+               $ ls
       sampler = do         
-        m <- st $ LDA.initial (U.singleton (get seed opts)) 
+        m <- LDA.initial (U.singleton (get seed opts)) 
                          (get topicNum opts)
                          (get alphasum opts)
                          (get beta opts)
                          (get exponent opts)
         let loop t z i = do
-              r <- st $ Trav.forM z $ \b -> do
+              r <- Trav.forM z $ \b -> do
                      Trav.forM b $ \s -> do
                        LDA.pass t m s
               M.when (get progressive opts && i == 1) $ do
-                let b = V.head z  
-                Fold.forM_ b $ \s -> do    
-                  ls <- st $ V.mapM (interpWordClasses m (get lambda opts)) s
-                  tell [best ls]
+                  let b = V.head z  
+                  Fold.forM_ b $ \s -> do    
+                    ls <- V.mapM (interpWordClasses m (get lambda opts)) s
+                    out (best ls) `seq` return ()
               return $! r
-        -- Initialize with batch sampler on prefix sbs_init     
+        -- -- Initialize with batch sampler on prefix sbs_init     
         Fold.forM_ sbs_init $ \sb -> do 
           Fold.foldlM (loop 1) sb [1..get initPasses opts] 
         -- Continue sampling
         Fold.forM_ (zip [1..] sbs_rest) $ \(t,sb) -> do
           Fold.foldlM (loop t) sb [1..get passes opts]
-        st $ LDA.finalize m    
-      (lda, labeled) = LST.runST (runWriterT sampler)
-  in (WordClass lda atomTabD atomTabW opts, labeled)
+        LDA.finalize m    
+      lda = ST.runST sampler
+  in WordClass lda atomTabD atomTabW opts
 
 type Symb  = Symbols.Symbols (U.Vector Char) (U.Vector Char)
 type Sent  = V.Vector LDA.Doc
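
As a small standalone check of the formatting helper this diff introduces (a sketch reusing the imports the commit adds; the example vector is made up), formatLabeling renders each class index with the decimal Builder and puts it on its own line:

    import qualified Data.Text.Lazy             as Text
    import qualified Data.Text.Lazy.IO          as Text
    import qualified Data.Text.Lazy.Builder     as Text
    import qualified Data.Text.Lazy.Builder.Int as Text
    import qualified Data.Vector                as V

    -- Same shape as the helper in the diff: one decimal class index per
    -- line, joined with unlines.
    formatLabeling :: V.Vector Int -> Text.Text
    formatLabeling = Text.unlines . V.toList
                     . V.map (Text.toLazyText . Text.decimal)

    main :: IO ()
    main = Text.putStrLn (formatLabeling (V.fromList [3, 0, 7]))
    -- prints 3, 0, and 7, each on its own line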

File colada/colada.cabal

 -- The package version. See the Haskell package versioning policy
 -- (http://www.haskell.org/haskellwiki/Package_versioning_policy) for
 -- standards guiding when and how versions should be incremented.
-Version:             0.2.4
+Version:             0.3.0
 
 -- A short (one-line) description of the package.
 Synopsis:            Colada implements incremental word class induction using online LDA
                , bytestring >= 0.9
                , vector-algorithms >= 0.5
                , mtl >= 2.0
-               , swift-lda >= 0.3 && <= 0.4
+               , swift-lda >= 0.4 && <= 0.5
   -- Modules not exported by this package.
   Other-modules: Colada.WordClass
                , Colada.Features

File colada/colada.hs

         $ ss  
     Learn { _options = o , _modelPath = p } -> do
       ss <- CoNLL.parse `fmap` Text.getContents
-      let (m, ls) = C.learn o ss
+      let m = C.learn o ss
       if (L.get C.progressive o)     
-        then do Text.putStr . Text.unlines . map formatLabeling $ ls
+        then do return () --Text.putStr . Text.unlines . map formatLabeling $ ls
         else do Text.putStr . C.summary $ m    
       BS.writeFile p . Serialize.encode $ m      
 

File swift-lda/swift-lda.cabal

 -- The package version. See the Haskell package versioning policy
 -- (http://www.haskell.org/haskellwiki/Package_versioning_policy) for
 -- standards guiding when and how versions should be incremented.
-Version:             0.3.1
+Version:             0.4.0
 
 -- A short (one-line) description of the package.
 Synopsis:            Online Latent Dirichlet Allocation