Commits

Anonymous committed 57a8bb9

add tests for Kendall, and fix one bug

  • Participants
  • Parent commits 638370e

Comments (0)

Files changed (3)

Statistics/Correlation/Kendall.hs

 -- $n_c = number of concordant pairs$, $n_d = number of discordant pairs$.
 
 module Statistics.Correlation.Kendall 
-    ( kendall ) where
+    ( kendall
+
+    -- * References
+    -- $references
+    ) where
 
 import qualified Data.Vector.Algorithms.Intro as I
 import qualified Data.Vector.Generic as G
   | otherwise  = runST $ do
     xy <- G.thaw xy'
     let n = GM.length xy
-        n_0 = (fromIntegral n * (fromIntegral n-1)) `shiftR` 1 :: Integer
-    n_dis <- newSTRef 0
+    n_dRef <- newSTRef 0
     I.sort xy
-    equalX <- numOfTiesBy ((==) `on` fst) xy
+    tieX <- numOfTiesBy ((==) `on` fst) xy
+    tieXY <- numOfTiesBy (==) xy
     tmp <- GM.new n
-    mergeSort (compare `on` snd) xy tmp n_dis
-    equalY <- numOfTiesBy ((==) `on` snd) xy
-    n_d <- readSTRef n_dis
-    let nu = n_0 - n_d - equalX - equalY - n_d
-        de = (n_0 - equalX) * (n_0 - equalY)
-    return $ fromIntegral nu / (sqrt.fromIntegral) de
+    mergeSort (compare `on` snd) xy tmp n_dRef
+    tieY <- numOfTiesBy ((==) `on` snd) xy
+    n_d <- readSTRef n_dRef
+    let n_0 = (fromIntegral n * (fromIntegral n-1)) `shiftR` 1 :: Integer
+        n_c = n_0 - n_d - tieX - tieY + tieXY
+    return $ fromIntegral (n_c - n_d) /
+             (sqrt.fromIntegral) ((n_0 - tieX) * (n_0 - tieY))
 {-# INLINE kendall #-}
 
 -- calculate number of tied pairs in a sorted vector

tests/Tests/Correlation.hs

+{-#LANGUAGE BangPatterns #-}
+
+module Tests.Correlation
+    ( tests ) where
+
+import Test.Framework
+import Test.Framework.Providers.QuickCheck2
+import qualified Data.Vector as V
+import Statistics.Correlation.Kendall
+
+-- | Test group collecting the properties for the correlation modules.
+tests :: Test
+tests = testGroup "Correlation" properties
+  where
+    properties = [testProperty "Kendall" testKendall]
+
+-- | Property: 'kendall' agrees with the brute-force reference
+-- 'kendallBruteForce' on the same data.  NaN never compares equal to
+-- itself, so a NaN reference result is matched with 'isNaN' instead of
+-- '=='.
+testKendall :: [(Double, Double)] -> Bool
+testKendall xy =
+    if isNaN expected
+       then isNaN actual
+       else expected == actual
+  where
+    expected = kendallBruteForce xy
+    actual   = kendall (V.fromList xy)
+
+-- | Reference O(n^2) implementation of Kendall's tau-b, used to check
+-- 'kendall'.
+--
+-- Every unordered pair of observations is classified as concordant,
+-- discordant or tied, and the result is
+-- @(n_c - n_d) / sqrt ((n_0 - n_1) * (n_0 - n_2))@ where @n_0@ is the
+-- total number of pairs, @n_1@ the number of pairs tied in @x@ and @n_2@
+-- the number of pairs tied in @y@.  The result is NaN when the
+-- denominator is zero (for example with fewer than two observations).
+kendallBruteForce :: [(Double, Double)] -> Double
+kendallBruteForce xy = (n_c - n_d) / sqrt ((n_0 - n_1) * (n_0 - n_2))
+  where
+    allPairs = f xy
+    (n_c, n_d, n_1, n_2) = foldl g (0,0,0,0) allPairs
+    n_0 = fromIntegral.length $ allPairs
+    g (!nc, !nd, !n1, !n2) ((x1, y1), (x2, y2))
+      | s > 0 = (nc+1, nd, n1, n2)
+      | s < 0 = (nc, nd+1, n1, n2)
+      | otherwise = if x1 == x2
+                       then if y1 == y2
+                               -- A pair tied in both coordinates is a tie
+                               -- in x *and* a tie in y, so it must count
+                               -- towards both n_1 and n_2.
+                               then (nc, nd, n1+1, n2+1)
+                               else (nc, nd, n1+1, n2)
+                       else (nc, nd, n1, n2+1)
+      where
+        -- Multiply the signs rather than the raw differences, so that
+        -- the product of two tiny (or huge) differences cannot underflow
+        -- to zero (or overflow) and be misclassified as a tie.
+        s = signum (x2 - x1) * signum (y2 - y1)
+    -- All unordered pairs: each element paired with every later element.
+    f (x:xs) = zip (repeat x) xs ++ f xs
+    f _ = []
 import qualified Tests.KDE as KDE
 import qualified Tests.NonParametric as NonParametric
 import qualified Tests.Transform as Transform
+import qualified Tests.Correlation as Correlation
 
 -- | Test-suite entry point: hands the @tests@ group of each test module
 -- (Distribution, KDE, NonParametric, Transform and Correlation) to
 -- @defaultMain@.
 main :: IO ()
 main = defaultMain [ Distribution.tests
                    , KDE.tests
                    , NonParametric.tests
                    , Transform.tests
+                   , Correlation.tests
                    ]