# Commits

committed d7f0e10

Added modular square roots and chinese remainder theorem

• Participants
• Parent commits 1e799dc

# Math/NumberTheory/Moduli.hs

`     , invertMod`
`     , powerMod`
`     , powerModInteger`
`+    , chineseRemainder`
`+      -- ** Partially checked input`
`+    , sqrtModP`
`       -- * Unchecked functions`
`     , jacobi'`
`     , powerMod'`
`     , powerModInteger'`
`+    , sqrtModPList`
`+    , sqrtModP'`
`+    , tonelliShanks`
`+    , sqrtModPP`
`+    , sqrtModPPList`
`+    , sqrtModF`
`+    , sqrtModFList`
`+    , chineseRemainder2`
`     ) where`
` `
` #include "MachDeps.h"`
` import Data.Bits`
` import Data.Array.Unboxed`
` import Data.Array.Base (unsafeAt)`
`+import Data.Maybe (fromJust)`
`+import Data.List (foldl', nub)`
`+import Control.Monad (foldM, liftM2)`
` `
`-import Math.NumberTheory.Utils (shiftToOddCount)`
`+import Math.NumberTheory.Utils (shiftToOddCount, splitOff)`
`+import Math.NumberTheory.GCD (extendedGCD)`
`+import Math.NumberTheory.Primes.Heap (sieveFrom)`
`+-- Guesstimated startup time for the Heap algorithm is lower than`
`+-- the cost to sieve an entire chunk.`
` `
` -- | Invert a number relative to a modulus.`
` --   If @number@ and @modulus@ are coprime, the result is`
` `
` #endif`
` `
`+-- | @sqrtModP n prime@ calculates a modular square root of @n@ modulo @prime@`
`+--   if that exists. The second argument /must/ be a (positive) prime, otherwise`
`+--   the computation may not terminate and if it does, may yield a wrong result.`
`+--   The precondition is /not/ checked.`
`+--`
`+--   If @prime@ is a prime and @n@ a quadratic residue modulo @prime@, the result`
`+--   is @Just r@ where @r^2 ≡ n (mod prime)@, if @n@ is a quadratic nonresidue,`
`+--   the result is @Nothing@.`
`+sqrtModP :: Integer -> Integer -> Maybe Integer`
`+sqrtModP n 2 = Just (n `mod` 2)`
`+sqrtModP n prime = case jacobi' n prime of`
`+                     0 -> Just 0`
`+                     1 -> Just (tonelliShanks (n `mod` prime) prime)`
`+                     _ -> Nothing`
`+`
`+-- | @sqrtModPList n prime@ computes the list of all square roots of @n@`
`+--   modulo @prime@. @prime@ /must/ be a (positive) prime.`
`+--   The precondition is /not/ checked.`
`+sqrtModPList :: Integer -> Integer -> [Integer]`
`+sqrtModPList n prime`
`+    | prime == 2    = [n `mod` 2]`
`+    | otherwise     = case sqrtModP n prime of`
`+                        Just 0 -> [0]`
`+                        Just r -> [r,prime-r] -- The group of units in Z/(p) is cyclic`
`+                        _      -> []`
`+`
`+-- | @sqrtModP' square prime@ finds a square root of @square@ modulo`
`+--   prime. @prime@ /must/ be a (positive) prime, and @sqaure@ /must/ be a`
`+--   quadratic residue modulo @prime@, i.e. @'jacobi square prime == 1@.`
`+--   The precondition is /not/ checked.`
`+sqrtModP' :: Integer -> Integer -> Integer`
`+sqrtModP' square prime`
`+    | prime == 2    = square`
`+    | rem4 prime == 3 = powerModInteger' square ((prime + 1) `quot` 4) prime`
`+    | otherwise     = tonelliShanks square prime`
`+`
`+-- | @tonelliShanks square prime@ calculates a square root of @square@`
`+--   modulo @prime@, where @prime@ is a prime of the form @4*k + 1@ and`
`+--   @square@ is a quadratic residue modulo @prime@, using the`
`+--   Tonelli-Shanks algorithm.`
`+--   No checks on the input are performed.`
`+tonelliShanks :: Integer -> Integer -> Integer`
`+tonelliShanks square prime = loop rc t1 generator log2`
`+  where`
`+    (log2,q) = shiftToOddCount (prime-1)`
`+    nonSquare = findNonSquare prime`
`+    generator = powerModInteger' nonSquare q prime`
`+    rc = powerModInteger' square ((q+1) `quot` 2) prime`
`+    t1 = powerModInteger' square q prime`
`+    msqr x = (x*x) `rem` prime`
`+    msquare 0 x = x`
`+    msquare k x = msquare (k-1) (msqr x)`
`+    findPeriod per 1 = per`
`+    findPeriod per x = findPeriod (per+1) (msqr x)`
`+    loop !r t c m`
`+        | t == 1    = r`
`+        | otherwise = loop nextR nextT nextC nextM`
`+          where`
`+            nextM = findPeriod 0 t`
`+            b     = msquare (m - 1 - nextM) c`
`+            nextR = (r*b) `rem` prime`
`+            nextC = msqr b`
`+            nextT = (t*nextC) `rem` prime`
`+`
`+-- | @sqrtModPP n (prime,expo)@ calculates a square root of @n@`
`+--   modulo @prime^expo@ if one exists. @prime@ /must/ be a`
`+--   (positive) prime. @expo@ must be positive, @n@ must be coprime`
`+--   to @prime@`
`+sqrtModPP :: Integer -> (Integer,Int) -> Maybe Integer`
`+sqrtModPP n (2,e) = sqM2P n e`
`+sqrtModPP n (prime,expo) = case sqrtModP n prime of`
`+                             Just r -> Just \$ fixup r`
`+                             _      -> Nothing`
`+  where`
`+    fixup r = case splitOff prime (r*r-n) of`
`+                (e,q) | expo <= e -> r`
`+                      | otherwise -> hoist (fromJust \$ invertMod (2*r) prime) r (q `mod` prime) (prime^e)`
`+                      --`
`+    hoist inv root elim pp`
`+        | expo <= ex    = root'`
`+        | otherwise     = hoist inv root' (nelim `mod` prime) (prime^ex)`
`+          where`
`+            root' = (root + (inv*(prime-elim))*pp) `mod` (prime*pp)`
`+            (ex, nelim) = splitOff prime (root'*root' - n)`
`+`
`+-- dirty, dirty`
`+sqM2P :: Integer -> Int -> Maybe Integer`
`+sqM2P n e`
`+    | e < 2     = Just (n `mod` 2)`
`+    | n' == 0   = Just 0`
`+    | e <= k    = Just 0`
`+    | odd k     = Nothing`
`+    | otherwise = fmap ((`mod` mdl) . (`shiftL` k2)) \$ solve s e2`
`+      where`
`+        mdl = 1 `shiftL` e`
`+        n' = n `mod` mdl`
`+        (k,s) = shiftToOddCount n'`
`+        k2 = k `quot` 2`
`+        e2 = e-k`
`+        solve r 1 = Just 1`
`+        solve 1 _ = Just 1`
`+        solve r p`
`+            | rem4 r == 3   = Nothing  -- otherwise r ≡ 1 (mod 4)`
`+            | p == 2        = Just 1   -- otherwise p >= 3`
`+            | rem8 r == 5   = Nothing  -- otherwise r ≡ 1 (mod 8)`
`+            | otherwise     = fixup r (fst \$ shiftToOddCount (r-1))`
`+              where`
`+                fixup x pw`
`+                    | pw >= e2  = Just x`
`+                    | otherwise = fixup x' pw'`
`+                      where`
`+                        x' = x + (1 `shiftL` (pw-1))`
`+                        d = x'*x' - r`
`+                        pw' = if d == 0 then e2 else fst (shiftToOddCount d)`
`+`
`+-- | @sqrtModF n primePowers@ calculates a square root of @n@ modulo`
`+--   @product [p^k | (p,k) <- primePowers]@ if one exists and all primes`
`+--   are distinct.`
`+sqrtModF :: Integer -> [(Integer,Int)] -> Maybe Integer`
`+sqrtModF n pps = do roots <- mapM (sqrtModPP n) pps`
`+                    chineseRemainder \$ zip roots (map (uncurry (^)) pps)`
`+`
`+-- | @sqrtModFList n primePowers@ calculates all square roots of @n@ modulo`
`+--   @product [p^k | (p,k) <- primePowers]@ if all primes are distinct.`
`+sqrtModFList :: Integer -> [(Integer,Int)] -> [Integer]`
`+sqrtModFList n pps = map fst \$ foldl1 (liftM2 comb) cs`
`+  where`
`+    ms :: [Integer]`
`+    ms = map (uncurry (^)) pps`
`+    rs :: [[Integer]]`
`+    rs = map (sqrtModPPList n) pps`
`+    cs :: [[(Integer,Integer)]]`
`+    cs = zipWith (\l m -> map (\x -> (x,m)) l) rs ms`
`+    comb t1@(_,m1) t2@(_,m2) = (chineseRemainder2 t1 t2,m1*m2)`
`+`
`+-- | @sqrtModPPList n (prime,expo)@ calculates the list of all`
`+--   square roots of @n@ modulo @prime^expo@. The same restriction`
`+--   as in 'sqrtModPP' applies to the arguments.`
`+sqrtModPPList :: Integer -> (Integer,Int) -> [Integer]`
`+sqrtModPPList n (2,expo)`
`+    = case sqM2P n expo of`
`+        Just r -> let m = 1 `shiftL` (expo-1)`
`+                  in nub [r, (r+m) `mod` (2*m), (m-r) `mod` (2*m), 2*m-r]`
`+sqrtModPPList n pe@(prime,expo)`
`+    = case sqrtModPP n pe of`
`+        Just 0 -> [0]`
`+        Just r -> [prime^expo - r, r] -- The group of units in Z/(p^e) is cyclic`
`+        _      -> []`
`+`
`+-- | Given a list @[(r_1,m_1), ..., (r_n,m_n)]@ of @(residue,modulus)@`
`+--   pairs, @chineseRemainder@ calculates the solution to the simultaneous`
`+--   congruences`
`+--`
`+-- >`
`+-- > r ≡ r_k (mod m_k)`
`+-- >`
`+--`
`+--   if all moduli are pairwise coprime. If not all moduli are`
`+--   pairwise coprime, the result is @Nothing@ regardless of whether`
`+--   a solution exists.`
`+chineseRemainder :: [(Integer,Integer)] -> Maybe Integer`
`+chineseRemainder remainders = foldM addRem 0 remainders`
`+  where`
`+    !modulus = product (map snd remainders)`
`+    addRem acc (r,m) = do`
`+        let cf = modulus `quot` m`
`+        inv <- invertMod cf m`
`+        Just \$! (acc + inv*cf*r) `mod` modulus`
`+`
`+-- | @chineseRemainder2 (r_1,m_1) (r_2,m_2)@ calculates the solution of`
`+--`
`+-- >`
`+-- > r ≡ r_k (mod m_k)`
`+--`
`+--   if @m_1@ and @m_2@ are coprime.`
`+chineseRemainder2 :: (Integer,Integer) -> (Integer,Integer) -> Integer`
`+chineseRemainder2 (r1, md1) (r2,md2)`
`+    = case extendedGCD md1 md2 of`
`+        (_,u,v) -> ((1 - u*md1)*r1 + (1 - v*md2)*r2) `mod` (md1*md2)`
`+`
` -- Utilities`
` `
` -- For large Integers, going via Int is much faster than bit-fiddling`
` `
` jac2 :: UArray Int Int`
` jac2 = array (0,7) [(0,0),(1,1),(2,0),(3,-1),(4,0),(5,-1),(6,0),(7,1)]`
`+`
`+findNonSquare :: Integer -> Integer`
`+findNonSquare n`
`+    | rem8 n == 5 || rem8 n == 3  = 2`
`+    | otherwise = search primelist`
`+      where`
`+        primelist = [3,5,7,11,13,17,19,23,29,31,37,41,43,47,53,59,61,67]`
`+                        ++ sieveFrom (68 + n `rem` 4) -- prevent sharing`
`+        search (p:ps)`
`+            | jacobi' p n == -1 = p`
`+            | otherwise         = search ps`
`+        search _ = error "Should never have happened, prime list exhausted."`
`+`
`+modProd :: [Integer] -> Integer -> Integer`
`+modProd fs md = foldl' (\a b -> (a*b) `mod` md) 1 fs`