Commits

Anonymous committed 3120c2e

add documentation, remove dependency on monad-primitive

Comments (0)

Files changed (2)

Statistics/Correlation/Kendall.hs

 {-# LANGUAGE BangPatterns, FlexibleContexts #-}
-
 -- |
 -- Module      : Statistics.Correlation.Kendall
--- Description : Kendall's τ
 --
--- Fast O(NlogN) implementation of Kendall's tau.
+-- Fast O(NlogN) implementation of
+-- <http://en.wikipedia.org/wiki/Kendall_tau_rank_correlation_coefficient Kendall's tau>.
+--
+-- This module implements Kendall's tau form b, which allows ties in the data.
+-- This is the same formula used by other statistical packages, e.g., R and MATLAB.
+--
+-- $$\tau = \frac{n_c - n_d}{\sqrt{(n_0 - n_1)(n_0 - n_2)}}$$
+--
+-- where $n_0 = n(n-1)/2$ is the total number of pairs,
+-- $n_1$ is the number of pairs tied for the first quantity,
+-- $n_2$ is the number of pairs tied for the second quantity,
+-- $n_c$ is the number of concordant pairs and $n_d$ is the number of
+-- discordant pairs.
 
 module Statistics.Correlation.Kendall 
     ( kendall ) where
 import qualified Data.Vector.Generic as G
 import qualified Data.Vector.Generic.Mutable as GM
 import Data.Function
-import Control.Monad.Primitive
+import Data.Bits
 import Control.Monad.ST
-import Data.Bits
-import Data.PrimRef
+import Data.STRef
 
-kendall :: (Ord a, G.Vector v (a, a)) => v (a, a) -> Double
-{-# INLINABLE kendall #-}
-kendall xy' = runST $ do
+-- | /O(nlogn)/ Compute Kendall's tau (form b) from a vector of paired data.
+-- Returns NaN when the vector holds at most one element, because no pair
+-- of observations can be formed.
+--
+-- A pair of observations that is tied in both components is included in
+-- the tie count of the first component and in the tie count of the second
+-- component, so it is added back once when the number of concordant pairs
+-- is derived from the total number of pairs.
+kendall :: (Ord a, Ord b, G.Vector v (a, b)) => v (a, b) -> Double
+kendall xy'
+  | G.length xy' <= 1 = 0/0
+  | otherwise  = runST $ do
     xy <- G.thaw xy'
     let n = GM.length xy
         n_0 = (fromIntegral n * (fromIntegral n-1)) `shiftR` 1 :: Integer
-    n_dis <- newPrimRef 0
+    n_dis <- newSTRef 0
     I.sort xy
-    equalX <- numOfEqualBy ((==) `on` fst) xy
+    equalX <- numOfTiesBy ((==) `on` fst) xy
+    -- Count pairs tied in both components while the data is still in the
+    -- order produced by the sort above (assumes 'I.sort' uses the Ord
+    -- instance of the tuple, which places identical elements next to
+    -- each other).
+    equalXY <- numOfTiesBy (==) xy
     tmp <- GM.new n
     mergeSort (compare `on` snd) xy tmp n_dis
-    equalY <- numOfEqualBy ((==) `on` snd) xy
-    n_d <- readPrimRef n_dis
+    equalY <- numOfTiesBy ((==) `on` snd) xy
+    n_d <- readSTRef n_dis
-    let nu = n_0 - n_d - equalX - equalY - n_d
+    -- n_0 = n_c + n_d + (equalX + equalY - equalXY): joint ties appear in
+    -- both equalX and equalY and must only be removed once.
+    let n_c = n_0 - n_d - equalX - equalY + equalXY
+        nu = n_c - n_d
         de = (n_0 - equalX) * (n_0 - equalY)
     return $ fromIntegral nu / (sqrt.fromIntegral) de
+{-# INLINE kendall #-}
 
-numOfEqualBy :: (PrimMonad m, GM.MVector v a)
-             => (a -> a -> Bool) -> v (PrimState m) a -> m Integer
-{-# INLINE numOfEqualBy #-}
-numOfEqualBy f xs = do
-    count <- newPrimRef (0::Integer)
+-- | Count tied pairs in a mutable vector. The vector is scanned once from
+-- left to right; neighbouring elements are compared with the supplied
+-- equality test, and every maximal run of @k@ consecutive equal elements
+-- contributes @k * (k - 1) / 2@ to the result. Only adjacent elements are
+-- compared, so equal elements that are not next to each other are not
+-- detected as a tie.
+numOfTiesBy :: GM.MVector v a
+             => (a -> a -> Bool) -> v s a -> ST s Integer
+numOfTiesBy f xs = do
+    count <- newSTRef (0::Integer)
     loop count (1::Int) (0::Int)
-    readPrimRef count
+    readSTRef count
     where
         n = GM.length xs
-        loop c !acc !i | i >= n - 1 = modifyPrimRef' c (+ g acc)
+        loop c !acc !i | i >= n - 1 = modifySTRef' c (+ g acc)
                        | otherwise = do
                            x1 <- GM.unsafeRead xs i
                            x2 <- GM.unsafeRead xs (i+1)
                            if f x1 x2
                               then loop c (acc+1) (i+1)
-                              else modifyPrimRef' c (+ g acc) >> loop c 1 (i+1)
+                              else modifySTRef' c (+ g acc) >> loop c 1 (i+1)
         g x = fromIntegral ((x * (x - 1)) `shiftR` 1)
+{-# INLINE numOfTiesBy #-}
 
--- Adapted from vector-algorithm
-mergeSort :: (PrimMonad m, GM.MVector v e)
+-- Implementation of Knight's merge sort (adapted from the vector-algorithms
+-- package). This function is used to count the number of discordant pairs.
+mergeSort :: GM.MVector v e
           => (e -> e -> Ordering)
-          -> v (PrimState m) e 
-          -> v (PrimState m) e 
-          -> PrimRef m Integer
-          -> m ()
-{-# INLINE mergeSort #-}
+          -> v s e 
+          -> v s e 
+          -> STRef s Integer
+          -> ST s ()
 mergeSort cmp src buf count = loop 0 (GM.length src - 1)
     where
         loop l u 
               case cmp eL eU of
                   GT -> do GM.unsafeWrite src l eU
                            GM.unsafeWrite src u eL
-                           modifyPrimRef' count (+1) 
+                           modifySTRef' count (+1) 
                   _ -> return ()
           | otherwise  = do
               let mid = (u + l) `shiftR` 1
               loop l mid
               loop mid u
               merge cmp (GM.unsafeSlice l (u-l+1) src) buf (mid - l) count
+{-# INLINE mergeSort #-}
 
-merge :: (PrimMonad m, GM.MVector v e)
+merge :: GM.MVector v e
       => (e -> e -> Ordering)
-      -> v (PrimState m) e
-      -> v (PrimState m) e
+      -> v s e
+      -> v s e
       -> Int
-      -> PrimRef m Integer
-      -> m ()
-{-# INLINE merge #-}
+      -> STRef s Integer
+      -> ST s ()
 merge cmp src buf mid count = do GM.unsafeCopy tmp lower
                                  eTmp <- GM.unsafeRead tmp 0
                                  eUpp <- GM.unsafeRead upper 0
 
     loop !low !iLow !eLow !high !iHigh !eHigh !iIns = case cmp eHigh eLow of
         LT -> do GM.unsafeWrite src iIns eHigh
-                 modifyPrimRef' count (+ fromIntegral (GM.length low - iLow))
+                 modifySTRef' count (+ fromIntegral (GM.length low - iLow))
                  wroteHigh low iLow eLow high (iHigh+1) (iIns+1)
         _  -> do GM.unsafeWrite src iIns eLow
                  wroteLow low (iLow+1) high iHigh eHigh (iIns+1)
+{-# INLINE merge #-}
 
 -- $references
 --
 -- * William R. Knight. (1966) A computer method for calculating Kendall's Tau
 --   with ungrouped data. /Journal of the American Statistical Association/,
---   Vol. 61, No. 314, Part 1, pp. 436-439. (http://www.jstor.org/pss/2282833).
+--   Vol. 61, No. 314, Part 1, pp. 436-439. <http://www.jstor.org/pss/2282833>
 --
     primitive         >= 0.3,
     vector            >= 0.7.1,
     vector-algorithms >= 0.4,
-    vector-binary-instances >= 0.2.1,
-    monad-primitive
+    vector-binary-instances >= 0.2.1
   if impl(ghc < 7.6)
     build-depends:
       ghc-prim