Aleksey Khudyakov avatar Aleksey Khudyakov committed 42a315d

Change interpolation API

Comments (0)

Files changed (1)

Numeric/Tools/Interpolation.hs

-{-# LANGUAGE FlexibleContexts   #-}
-{-# LANGUAGE DeriveDataTypeable #-}
-{-# LANGUAGE TypeFamilies       #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE FlexibleContexts      #-}
+{-# LANGUAGE FlexibleInstances     #-}
+{-# LANGUAGE DeriveDataTypeable    #-}
+{-# LANGUAGE TypeFamilies          #-}
 -- |
 -- Module    : Numeric.Tools.Interpolation
 -- Copyright : (c) 2011 Aleksey Khudyakov
     -- * Cubic splines
   , CubicSpline
   , cubicSpline
-    --
+    -- * Reexport of mesh type
   , module Numeric.Tools.Mesh
+    -- * Default methods
+  , defaultInterpSize
+  , defaultInterpIndex
   ) where
 
 import Control.Monad.ST   (runST)
 
 ----------------------------------------------------------------
 
--- | Interpolation for arbitraty 1D meshes. Data type which perform
---   interpolations is parametrized by mesh type.
-class Interpolation a where
+-- | Type class for Interpolation algorithms. Since some algorithms
+--   require some particular mesh type it's present as type class
+--   parameter. Every algorithms should be instance of 'Indexable' as
+--   well. Indexing should return pair @(x,y)@ for u'th mesh node.
+class ( IndexVal (interp mesh) ~ (Double,Double), Indexable (interp mesh)
+      , IndexVal mesh ~ Double, Mesh mesh
+      ) => Interpolation interp mesh where
   -- | Interpolate function at some point. Function should not
   --   fail outside of mesh however it may and most likely will give
   --   nonsensical results
-  at          :: (IndexVal m ~ Double, Mesh m) => a m -> Double -> Double
+  at          :: interp mesh -> Double -> Double
   -- | Use table of already evaluated function and mesh. Sizes of mesh
   --   and table must coincide but it's not checked. Do not use this
   --   function use 'tabulate' instead.
-  unsafeTabulate :: (IndexVal m ~ Double, Mesh m, G.Vector v Double) => m -> v Double -> a m
+  unsafeTabulate :: (G.Vector v Double) => mesh -> v Double -> interp mesh
   -- | Get mesh.
-  interpolationMesh  :: a m -> m
+  interpolationMesh  :: interp mesh -> mesh
   -- | Get table of function values 
-  interpolationTable :: a m -> U.Vector Double
+  interpolationTable :: interp mesh -> U.Vector Double
     
 -- | Tabulate function.
-tabulateFun :: (IndexVal m ~ Double, Mesh m, Interpolation a) => m -> (Double -> Double) -> a m
+tabulateFun :: (Interpolation i m) => m -> (Double -> Double) -> i m
 tabulateFun mesh f = unsafeTabulate mesh $ U.generate (size mesh) (f . unsafeIndex mesh)
 {-# INLINE tabulateFun #-}
 
 -- | Use table of already evaluated function and mesh. Sizes of mesh
 --   and table must coincide. 
-tabulate :: (Interpolation a, IndexVal m ~ Double, Mesh m, G.Vector v Double) => m -> v Double -> a m
+tabulate :: (Interpolation i m, G.Vector v Double) => m -> v Double -> i m
 {-# INLINE tabulate #-}
 tabulate mesh tbl
   | size mesh /= G.length tbl = error "Numeric.Tools.Interpolation.tabulate: size of vector and mesh do not match"
 ----------------------------------------------------------------
 
 -- | Data for linear interpolation
-data LinearInterp a = LinearInterp { linearInterpMesh  :: a
-                                   , linearInterpTable :: U.Vector Double
-                                   }
-                      deriving (Show,Eq,Data,Typeable)
+data LinearInterp mesh = LinearInterp
+  { linearInterpMesh  :: mesh
+  , linearInterpTable :: U.Vector Double
+  }
+  deriving (Show,Eq,Data,Typeable)
 
 -- | Function used to fix types
-linearInterp :: LinearInterp a -> LinearInterp a
+linearInterp :: LinearInterp mesh -> LinearInterp mesh
 linearInterp = id
 
-instance Mesh a => Indexable (LinearInterp a) where
-  type IndexVal (LinearInterp a) = (IndexVal a, Double)
-  size        (LinearInterp _    vec)   = size vec
-  unsafeIndex (LinearInterp mesh vec) i = ( unsafeIndex mesh i
-                                          , unsafeIndex vec  i
-                                          )
+instance (Mesh mesh, IndexVal mesh ~ Double) => Indexable (LinearInterp mesh) where
+  type IndexVal (LinearInterp mesh) = (IndexVal mesh, Double)
+  size        = defaultInterpSize
+  unsafeIndex = defaultInterpIndex
   {-# INLINE size        #-}
   {-# INLINE unsafeIndex #-}
 
-instance Interpolation LinearInterp where
+instance (Mesh mesh, IndexVal mesh ~ Double) => Interpolation LinearInterp mesh where
   at                      = linearInterpolation
   unsafeTabulate mesh tbl = LinearInterp mesh (G.convert tbl)
   interpolationMesh       = linearInterpMesh
 -- | Natural cubic splines
 data CubicSpline a = CubicSpline { cubicSplineMesh   :: a
                                  , cubicSplineTable  :: U.Vector Double
-                                 , cubicSplineY2     :: U.Vector Double
+                                 , _cubicSplineY2    :: U.Vector Double
                                  }
                    deriving (Eq,Show,Data,Typeable)
 
 cubicSpline :: CubicSpline a -> CubicSpline a 
 cubicSpline = id
 
-instance Interpolation CubicSpline where
+instance (Mesh mesh, IndexVal mesh ~ Double) => Indexable (CubicSpline mesh) where
+  type IndexVal (CubicSpline mesh) = (IndexVal mesh, Double)
+  size        = defaultInterpSize
+  unsafeIndex = defaultInterpIndex
+  {-# INLINE size        #-}
+  {-# INLINE unsafeIndex #-}
+
+instance (Mesh mesh, IndexVal mesh ~ Double) => Interpolation CubicSpline mesh where
   at (CubicSpline mesh ys y2) x = y
     where
     i  = safeFindIndex mesh x
     where
       n = size mesh - 2
 {-# INLINE safeFindIndex #-}
+
+-- | Default implementation of 'size' for interpolation algorithms.
+defaultInterpSize :: Interpolation i m => i m -> Int
+defaultInterpSize = U.length . interpolationTable
+{-# INLINE defaultInterpSize #-}
+
+-- | Default implementation of 'unsafeIndex' for interpolation algorithms.
+defaultInterpIndex :: Interpolation i m => i m -> Int -> (Double, Double)
+defaultInterpIndex tbl i = ( unsafeIndex (interpolationMesh  tbl) i
+                           , unsafeIndex (interpolationTable tbl) i
+                           )
+{-# INLINE defaultInterpIndex #-}
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.