diff --git a/poly.cabal b/poly.cabal index 126530b..b9e5bb8 100644 --- a/poly.cabal +++ b/poly.cabal @@ -31,6 +31,7 @@ library hs-source-dirs: src exposed-modules: Data.Poly + Data.Poly.Interpolation Data.Poly.Laurent Data.Poly.Semiring Data.Poly.Orthogonal @@ -64,6 +65,7 @@ library build-depends: base >= 4.12 && < 5, + containers >= 0.5.4, deepseq >= 1.1 && < 1.6, primitive >= 0.6, semirings >= 0.5.2, @@ -89,6 +91,7 @@ test-suite poly-tests Dense DenseLaurent DFT + Interpolation Orthogonal Quaternion TestUtils @@ -100,6 +103,7 @@ test-suite poly-tests SparseLaurent build-depends: base >=4.10 && <5, + containers, mod >=0.1.2, poly, QuickCheck >=2.12 && <2.14.3, diff --git a/src/Data/Poly/Internal/Dense.hs b/src/Data/Poly/Internal/Dense.hs index 04e0636..3f6f8ab 100644 --- a/src/Data/Poly/Internal/Dense.hs +++ b/src/Data/Poly/Internal/Dense.hs @@ -27,6 +27,7 @@ module Data.Poly.Internal.Dense , scale , pattern X , eval + , evalk , subst , deriv , integral @@ -460,6 +461,14 @@ substitute' f (Poly cs) x = fst' $ G.foldl' (\(acc :*: xn) cn -> acc `plus` f cn xn :*: x `times` xn) (zero :*: one) cs {-# INLINE substitute' #-} +-- | Evaluate the kth derivative of the polynomial at a given point. +evalk :: (G.Vector v a, Num a) => Int -> Poly v a -> a -> a +evalk k (Poly cs) x = fst' $ + G.ifoldl' (\(acc :*: xn) i cn -> acc + kth i * cn * xn :*: x * xn) (0 :*: 1) (G.drop k cs) + where + kth i = fromIntegral $ product [(i + 1)..(i + k)] +{-# INLINE evalk #-} + -- | Take the derivative of the polynomial. -- -- >>> deriv (X^3 + 3 * X) :: UPoly Int diff --git a/src/Data/Poly/Interpolation.hs b/src/Data/Poly/Interpolation.hs new file mode 100644 index 0000000..548fd30 --- /dev/null +++ b/src/Data/Poly/Interpolation.hs @@ -0,0 +1,38 @@ +module Data.Poly.Interpolation + ( lagrange + , hermite + ) where + +import Prelude hiding (Foldable(..)) +import qualified Data.Foldable as F + +import Data.Map + +import Data.Poly.Internal.Dense +import qualified Data.Vector.Generic as G + +-- | Compute the [Lagrange interpolating polynomial](https://en.wikipedia.org/wiki/Lagrange_polynomial). +-- +-- This is the (unique) polynomial of minimal degree interpolating the given points. +-- The keys are the @x@ values and the associated @y@ is the value at @x@. +lagrange :: (G.Vector v a, Eq a, Fractional a) => Map a a -> Poly v a +lagrange = fst . foldlWithKey' f (0, 1) + where + f (p, w) x y = + let a = (y - eval p x) / eval w x + in (p + scale 0 a w, scale 1 1 w - scale 0 x w) -- (p + a * w, w * (X - x)) +{-# INLINABLE lagrange #-} + +-- | Compute the [Hermite interpolating polynomial](https://en.wikipedia.org/wiki/Hermite_interpolation). +-- +-- This is the (unique) polynomial of minimal degree interpolating the given points and derivatives. +-- The keys are the @x@ values and the associated @ys@ are the values and derivatives at @x@, where @ys !! k@ is the k-th derivative. +hermite :: (G.Vector v a, Eq a, Fractional a) => Map a [a] -> Poly v a +hermite = fst . foldlWithKey' f (0, 1) + where + f (p, w) x ys = let (_, p', w') = F.foldl' g (0, p, w) ys in (p', w') + where + g (k, p', w') y = + let a = (y - evalk k p' x) / evalk k w' x + in (k + 1, p' + scale 0 a w', scale 1 1 w' - scale 0 x w') +{-# INLINABLE hermite #-} diff --git a/test/Interpolation.hs b/test/Interpolation.hs new file mode 100644 index 0000000..967d32a --- /dev/null +++ b/test/Interpolation.hs @@ -0,0 +1,39 @@ +{-# LANGUAGE ScopedTypeVariables #-} + +module Interpolation (testSuite) where + +import Data.Map +import Data.Poly hiding (scale) +import Data.Poly.Interpolation +import Test.Tasty +import Test.Tasty.QuickCheck + +import TestUtils () + +testSuite :: TestTree +testSuite = localOption (QuickCheckMaxSize 10) $ testGroup "Interpolation" + [ testProperty "lagrange interpolates" prop_lagrange + , testProperty "hermite interpolates" prop_hermite + , testProperty "lagrange == hermite" prop_lagrange_hermite + ] + +prop_lagrange :: Map Rational Rational -> Property +prop_lagrange xys = + let p = lagrange xys :: VPoly Rational + in conjoin $ fmap (\(x, y) -> eval p x === y) (toList xys) + +prop_hermite :: Map Rational (Rational, Rational, Rational, Rational) -> Property +prop_hermite xys = + let + p = hermite (fmap (\(y, y', y'', y''') -> [y, y', y'', y''']) xys) :: VPoly Rational + p' = deriv p + p'' = deriv p' + p''' = deriv p'' + in conjoin $ fmap (\(x, (y, y', y'', y''')) -> eval p x === y .&&. eval p' x === y' .&&. eval p'' x === y'' .&&. eval p''' x === y''') (toList xys) + +prop_lagrange_hermite :: Map Rational Rational -> Property +prop_lagrange_hermite xys = + let + p = lagrange xys :: VPoly Rational + q = hermite (fmap (\y -> [y]) xys) :: VPoly Rational + in p === q diff --git a/test/Main.hs b/test/Main.hs index 4b45cfd..af756d9 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -8,6 +8,7 @@ import qualified Dense import qualified DenseLaurent import qualified DFT import qualified Orthogonal +import qualified Interpolation #ifdef SupportSparse import qualified Multi import qualified MultiLaurent @@ -20,6 +21,7 @@ main = defaultMain $ testGroup "All" [ Dense.testSuite , DenseLaurent.testSuite , DFT.testSuite + , Interpolation.testSuite , Orthogonal.testSuite #ifdef SupportSparse , Sparse.testSuite