Commits

dafis committed d7f0e10

Added modular square roots and chinese remainder theorem

  • Participants
  • Parent commits 1e799dc

Comments (0)

Files changed (1)

Math/NumberTheory/Moduli.hs

     , invertMod
     , powerMod
     , powerModInteger
+    , chineseRemainder
+      -- ** Partially checked input
+    , sqrtModP
       -- * Unchecked functions
     , jacobi'
     , powerMod'
     , powerModInteger'
+    , sqrtModPList
+    , sqrtModP'
+    , tonelliShanks
+    , sqrtModPP
+    , sqrtModPPList
+    , sqrtModF
+    , sqrtModFList
+    , chineseRemainder2
     ) where
 
 #include "MachDeps.h"
 import Data.Bits
 import Data.Array.Unboxed
 import Data.Array.Base (unsafeAt)
+import Data.Maybe (fromJust)
+import Data.List (foldl', nub)
+import Control.Monad (foldM, liftM2)
 
-import Math.NumberTheory.Utils (shiftToOddCount)
+import Math.NumberTheory.Utils (shiftToOddCount, splitOff)
+import Math.NumberTheory.GCD (extendedGCD)
+import Math.NumberTheory.Primes.Heap (sieveFrom)
+-- Guesstimated startup time for the Heap algorithm is lower than
+-- the cost to sieve an entire chunk.
 
 -- | Invert a number relative to a modulus.
 --   If @number@ and @modulus@ are coprime, the result is
 
 #endif
 
+-- | @sqrtModP n prime@ calculates a modular square root of @n@ modulo @prime@
+--   if that exists. The second argument /must/ be a (positive) prime, otherwise
+--   the computation may not terminate and if it does, may yield a wrong result.
+--   The precondition is /not/ checked.
+--
+--   If @prime@ is a prime and @n@ a quadratic residue modulo @prime@, the result
+--   is @Just r@ where @r^2 ≡ n (mod prime)@, if @n@ is a quadratic nonresidue,
+--   the result is @Nothing@.
+sqrtModP :: Integer -> Integer -> Maybe Integer
+sqrtModP n 2 = Just (n `mod` 2)
+sqrtModP n prime = case jacobi' n prime of
+                     0 -> Just 0
+                     1 -> Just (tonelliShanks (n `mod` prime) prime)
+                     _ -> Nothing
+
+-- | @sqrtModPList n prime@ computes the list of all square roots of @n@
+--   modulo @prime@. @prime@ /must/ be a (positive) prime.
+--   The precondition is /not/ checked.
+sqrtModPList :: Integer -> Integer -> [Integer]
+sqrtModPList n prime
+    | prime == 2    = [n `mod` 2]
+    | otherwise     = case sqrtModP n prime of
+                        Just 0 -> [0]
+                        Just r -> [r,prime-r] -- The group of units in Z/(p) is cyclic
+                        _      -> []
+
+-- | @sqrtModP' square prime@ finds a square root of @square@ modulo
+--   prime. @prime@ /must/ be a (positive) prime, and @sqaure@ /must/ be a
+--   quadratic residue modulo @prime@, i.e. @'jacobi square prime == 1@.
+--   The precondition is /not/ checked.
+sqrtModP' :: Integer -> Integer -> Integer
+sqrtModP' square prime
+    | prime == 2    = square
+    | rem4 prime == 3 = powerModInteger' square ((prime + 1) `quot` 4) prime
+    | otherwise     = tonelliShanks square prime
+
+-- | @tonelliShanks square prime@ calculates a square root of @square@
+--   modulo @prime@, where @prime@ is a prime of the form @4*k + 1@ and
+--   @square@ is a quadratic residue modulo @prime@, using the
+--   Tonelli-Shanks algorithm.
+--   No checks on the input are performed.
+tonelliShanks :: Integer -> Integer -> Integer
+tonelliShanks square prime = loop rc t1 generator log2
+  where
+    (log2,q) = shiftToOddCount (prime-1)
+    nonSquare = findNonSquare prime
+    generator = powerModInteger' nonSquare q prime
+    rc = powerModInteger' square ((q+1) `quot` 2) prime
+    t1 = powerModInteger' square q prime
+    msqr x = (x*x) `rem` prime
+    msquare 0 x = x
+    msquare k x = msquare (k-1) (msqr x)
+    findPeriod per 1 = per
+    findPeriod per x = findPeriod (per+1) (msqr x)
+    loop !r t c m
+        | t == 1    = r
+        | otherwise = loop nextR nextT nextC nextM
+          where
+            nextM = findPeriod 0 t
+            b     = msquare (m - 1 - nextM) c
+            nextR = (r*b) `rem` prime
+            nextC = msqr b
+            nextT = (t*nextC) `rem` prime
+
+-- | @sqrtModPP n (prime,expo)@ calculates a square root of @n@
+--   modulo @prime^expo@ if one exists. @prime@ /must/ be a
+--   (positive) prime. @expo@ must be positive, @n@ must be coprime
+--   to @prime@
+sqrtModPP :: Integer -> (Integer,Int) -> Maybe Integer
+sqrtModPP n (2,e) = sqM2P n e
+sqrtModPP n (prime,expo) = case sqrtModP n prime of
+                             Just r -> Just $ fixup r
+                             _      -> Nothing
+  where
+    fixup r = case splitOff prime (r*r-n) of
+                (e,q) | expo <= e -> r
+                      | otherwise -> hoist (fromJust $ invertMod (2*r) prime) r (q `mod` prime) (prime^e)
+                      --
+    hoist inv root elim pp
+        | expo <= ex    = root'
+        | otherwise     = hoist inv root' (nelim `mod` prime) (prime^ex)
+          where
+            root' = (root + (inv*(prime-elim))*pp) `mod` (prime*pp)
+            (ex, nelim) = splitOff prime (root'*root' - n)
+
+-- dirty, dirty
+sqM2P :: Integer -> Int -> Maybe Integer
+sqM2P n e
+    | e < 2     = Just (n `mod` 2)
+    | n' == 0   = Just 0
+    | e <= k    = Just 0
+    | odd k     = Nothing
+    | otherwise = fmap ((`mod` mdl) . (`shiftL` k2)) $ solve s e2
+      where
+        mdl = 1 `shiftL` e
+        n' = n `mod` mdl
+        (k,s) = shiftToOddCount n'
+        k2 = k `quot` 2
+        e2 = e-k
+        solve r 1 = Just 1
+        solve 1 _ = Just 1
+        solve r p
+            | rem4 r == 3   = Nothing  -- otherwise r ≡ 1 (mod 4)
+            | p == 2        = Just 1   -- otherwise p >= 3
+            | rem8 r == 5   = Nothing  -- otherwise r ≡ 1 (mod 8)
+            | otherwise     = fixup r (fst $ shiftToOddCount (r-1))
+              where
+                fixup x pw
+                    | pw >= e2  = Just x
+                    | otherwise = fixup x' pw'
+                      where
+                        x' = x + (1 `shiftL` (pw-1))
+                        d = x'*x' - r
+                        pw' = if d == 0 then e2 else fst (shiftToOddCount d)
+
+-- | @sqrtModF n primePowers@ calculates a square root of @n@ modulo
+--   @product [p^k | (p,k) <- primePowers]@ if one exists and all primes
+--   are distinct.
+sqrtModF :: Integer -> [(Integer,Int)] -> Maybe Integer
+sqrtModF n pps = do roots <- mapM (sqrtModPP n) pps
+                    chineseRemainder $ zip roots (map (uncurry (^)) pps)
+
+-- | @sqrtModFList n primePowers@ calculates all square roots of @n@ modulo
+--   @product [p^k | (p,k) <- primePowers]@ if all primes are distinct.
+sqrtModFList :: Integer -> [(Integer,Int)] -> [Integer]
+sqrtModFList n pps = map fst $ foldl1 (liftM2 comb) cs
+  where
+    ms :: [Integer]
+    ms = map (uncurry (^)) pps
+    rs :: [[Integer]]
+    rs = map (sqrtModPPList n) pps
+    cs :: [[(Integer,Integer)]]
+    cs = zipWith (\l m -> map (\x -> (x,m)) l) rs ms
+    comb t1@(_,m1) t2@(_,m2) = (chineseRemainder2 t1 t2,m1*m2)
+
+-- | @sqrtModPPList n (prime,expo)@ calculates the list of all
+--   square roots of @n@ modulo @prime^expo@. The same restriction
+--   as in 'sqrtModPP' applies to the arguments.
+sqrtModPPList :: Integer -> (Integer,Int) -> [Integer]
+sqrtModPPList n (2,expo)
+    = case sqM2P n expo of
+        Just r -> let m = 1 `shiftL` (expo-1)
+                  in nub [r, (r+m) `mod` (2*m), (m-r) `mod` (2*m), 2*m-r]
+sqrtModPPList n pe@(prime,expo)
+    = case sqrtModPP n pe of
+        Just 0 -> [0]
+        Just r -> [prime^expo - r, r] -- The group of units in Z/(p^e) is cyclic
+        _      -> []
+
+-- | Given a list @[(r_1,m_1), ..., (r_n,m_n)]@ of @(residue,modulus)@
+--   pairs, @chineseRemainder@ calculates the solution to the simultaneous
+--   congruences
+--
+-- >
+-- > r ≡ r_k (mod m_k)
+-- >
+--
+--   if all moduli are pairwise coprime. If not all moduli are
+--   pairwise coprime, the result is @Nothing@ regardless of whether
+--   a solution exists.
+chineseRemainder :: [(Integer,Integer)] -> Maybe Integer
+chineseRemainder remainders = foldM addRem 0 remainders
+  where
+    !modulus = product (map snd remainders)
+    addRem acc (r,m) = do
+        let cf = modulus `quot` m
+        inv <- invertMod cf m
+        Just $! (acc + inv*cf*r) `mod` modulus
+
+-- | @chineseRemainder2 (r_1,m_1) (r_2,m_2)@ calculates the solution of
+--
+-- >
+-- > r ≡ r_k (mod m_k)
+--
+--   if @m_1@ and @m_2@ are coprime.
+chineseRemainder2 :: (Integer,Integer) -> (Integer,Integer) -> Integer
+chineseRemainder2 (r1, md1) (r2,md2)
+    = case extendedGCD md1 md2 of
+        (_,u,v) -> ((1 - u*md1)*r1 + (1 - v*md2)*r2) `mod` (md1*md2)
+
 -- Utilities
 
 -- For large Integers, going via Int is much faster than bit-fiddling
 
 jac2 :: UArray Int Int
 jac2 = array (0,7) [(0,0),(1,1),(2,0),(3,-1),(4,0),(5,-1),(6,0),(7,1)]
+
+findNonSquare :: Integer -> Integer
+findNonSquare n
+    | rem8 n == 5 || rem8 n == 3  = 2
+    | otherwise = search primelist
+      where
+        primelist = [3,5,7,11,13,17,19,23,29,31,37,41,43,47,53,59,61,67]
+                        ++ sieveFrom (68 + n `rem` 4) -- prevent sharing
+        search (p:ps)
+            | jacobi' p n == -1 = p
+            | otherwise         = search ps
+        search _ = error "Should never have happened, prime list exhausted."
+
+modProd :: [Integer] -> Integer -> Integer
+modProd fs md = foldl' (\a b -> (a*b) `mod` md) 1 fs