{-# LANGUAGE BangPatterns #-}
-- |
-- Module      : Number.Polynomial
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : Good

module Number.Polynomial
	( Monomial(..)
	-- * polynomial operations
	, Polynomial
	, toList
	, fromList
	, addPoly
	, subPoly
	, mulPoly
	, squarePoly
	, expPoly
	, divPoly
	, negPoly
	) where

import Data.List (intercalate, sort)
import Data.Vector ((!), Vector)
import qualified Data.Vector as V
import Control.Arrow (first)

data Monomial = Monomial {-# UNPACK #-} !Int !Integer
	deriving (Eq)

data Polynomial = Polynomial (Vector Monomial)
	deriving (Eq)

instance Ord Monomial where
	compare (Monomial w1 v1) (Monomial w2 v2) =
		case compare w1 w2 of
			EQ -> compare v1 v2
			r  -> r

instance Show Monomial where
	show (Monomial w v) = show v ++ "x^" ++ show w

instance Show Polynomial where
	show (Polynomial p) = intercalate "+" $ map show $ V.toList p

toList :: Polynomial -> [Monomial]
toList (Polynomial p) = V.toList p

fromList :: [Monomial] -> Polynomial
fromList = Polynomial . V.fromList . reverse . sort . filterZero
	where
		filterZero = filter (\(Monomial _ v) -> v /= 0)

getWeight :: Polynomial -> Int -> Maybe Integer
getWeight (Polynomial p) n = look 0
	where
		plen = V.length p
		look !i
			| i >= plen = Nothing
			| otherwise =
				let (Monomial w v) = p ! i in
				case compare w n of
					LT -> Nothing
					EQ -> Just v
					GT -> look (i+1)
		

mergePoly :: (Integer -> Integer -> Integer) -> Polynomial -> Polynomial -> Polynomial
mergePoly f (Polynomial p1) (Polynomial p2) = fromList $ loop 0 0
	where
		l1 = V.length p1
		l2 = V.length p2
		loop !i1 !i2
			| i1 == l1 && i2 == l2 = []
			| i1 == l1             = (p2 ! i2) : loop i1 (i2+1)
			| i2 == l2             = (p1 ! i1) : loop (i1+1) i2
			| otherwise            =
				let (coef, i1inc, i2inc) = addCoef (p1 ! i1) (p2 ! i2) in
				coef : loop (i1+i1inc) (i2+i2inc)
		addCoef m1@(Monomial w1 v1) (Monomial w2 v2) =
			case compare w1 w2 of
				LT -> (Monomial w2 (f 0 v2), 0, 1)
				EQ -> (Monomial w1 (f v1 v2), 1, 1)
				GT -> (m1, 1, 0)

addPoly :: Polynomial -> Polynomial -> Polynomial
addPoly = mergePoly (+)

subPoly :: Polynomial -> Polynomial -> Polynomial
subPoly = mergePoly (-)

negPoly :: Polynomial -> Polynomial
negPoly (Polynomial p) = Polynomial $ V.map negateMonomial p
	where negateMonomial (Monomial w v) = Monomial w (-v)

mulPoly :: Polynomial -> Polynomial -> Polynomial
mulPoly p1@(Polynomial v1) p2@(Polynomial v2) =
	fromList $ filter (\(Monomial _ v) -> v /= 0) $ map (\i -> Monomial i (c i)) $ reverse [0..(m+n)]
	where
		(Monomial m _) = v1 ! 0
		(Monomial n _) = v2 ! 0
		c r = foldl (\acc i -> (b $ r-i) * (a $ i) + acc) 0 [0..r]
			where
				a = maybe 0 id . getWeight p1
				b = maybe 0 id . getWeight p2

squarePoly :: Polynomial -> Polynomial
squarePoly p = p `mulPoly` p

expPoly :: Polynomial -> Integer -> Polynomial
expPoly p e = loop p e
	where
		loop t 0 = t
		loop t n = loop (squarePoly t) (n-1)

divPoly :: Polynomial -> Polynomial -> (Polynomial, Polynomial)
divPoly p1 p2@(Polynomial pp2) = first fromList $ divLoop p1
	where divLoop d1@(Polynomial pp1)
		| V.null pp1 = ([], d1)
		| otherwise  =
			let (Monomial w1 v1) = pp1 ! 0 in
			let (Monomial w2 v2) = pp2 ! 0 in
			let w = w1 - w2 in
			let (v,r) = v1 `divMod` v2 in
			if w >= 0 && r == 0
				then
					let mono = (Monomial w v) in
					let remain = d1 `subPoly` (p2 `mulPoly` (fromList [mono])) in
					let (l, finalRem) = divLoop remain in
					(mono : l, finalRem)
				else
					([], d1)