module RecFuse (averageFirst2N) where

-- | Fixed point of a functor @f@: a recursive structure built from layers of
-- @f@, where every recursive position of a layer holds another 'Fix' value.
newtype Fix f = Fix (f (Fix f))

-- | Base functor of a list of @a@. The @b@ parameter marks the recursive
-- (tail) position, so @'Fix' ('ListF' a)@ represents a list of @a@.
data ListF a b
  = Nil       -- ^ End of the list.
  | Cons a b  -- ^ An element followed by the tail position.

-- | Maps over the recursive (tail) position only; the element is untouched.
instance Functor (ListF a) where
  fmap f layer = case layer of
    Nil -> Nil
    Cons hd tl -> Cons hd (f tl)
  {-# INLINE fmap #-}

-- | One step of a strict left fold. Given the current layer and the
-- accumulator, it produces the next accumulator. Because @x@ is universally
-- quantified, a step cannot inspect the layer's recursive positions. It can
-- only look at the non-recursive contents, such as a list element.
type StrictFoldFun f a = forall x . f x -> a -> a

-- | Fold step that adds each element to the running total. Both the element
-- and the incoming total are forced.
sum' :: StrictFoldFun (ListF Int) Int
sum' (Cons !elt _) !total = total + elt
sum' Nil !total = total
{-# INLINE sum' #-}

-- | Fold step that counts layers. It adds one for every 'Cons' and leaves the
-- (forced) count unchanged at 'Nil'.
length' :: StrictFoldFun (ListF Int) Int
length' (Cons _ _) !count = count + 1
length' Nil !count = count
{-# INLINE length' #-}

-- | Describes how an accumulator flows through one layer. The layer's
-- recursive positions already hold the accumulator transformers for the
-- sub-structures, and the result is the transformer for the whole layer.
type FoldShape f = forall x . f (x -> x) -> (x -> x)

-- | 'FoldShape' for lists. At 'Nil' the accumulator is returned as is. At
-- 'Cons' it is handed on to the transformer for the tail.
listShape :: FoldShape (ListF a)
listShape (Cons _ onTail) acc = onTail acc
listShape Nil acc = acc
{-# INLINE listShape #-}

-- | Runs two fold steps side by side over the same layer and pairs their
-- accumulators. Both incoming components are forced before stepping. The
-- outgoing components are left unevaluated here; the bang patterns of the
-- next call force them.
combine :: StrictFoldFun f a -> StrictFoldFun f b -> StrictFoldFun f (a, b)
combine stepL stepR layer (!accL, !accR) =
  let nextL = stepL layer accL
      nextR = stepR layer accR
   in (nextL, nextR)
{-# INLINE combine #-}

-- | Catamorphism: collapses a 'Fix' structure bottom-up. It maps itself over
-- each layer's recursive positions, then applies the algebra to the result.
--
-- Inlining is delayed by the NOINLINE[3] pragma. Presumably this gives the
-- RULES in this module whose left-hand sides mention 'cata' a chance to fire.
-- NOTE(review): GHC's default simplifier phases are numbered 2, 1, 0, so
-- phase 3 may not correspond to a distinct phase. Confirm the intended
-- activation.
cata :: Functor f => (f a -> a) -> Fix f -> a
cata alg = collapse
  where
    collapse (Fix layer) = alg (fmap collapse layer)
{-# NOINLINE[3] cata #-}


-- | Strict left fold over a 'Fix' structure. It is expressed as a 'cata' that
-- builds an accumulator transformer, which is then applied to the initial
-- value @z@. At each layer the (forced) incoming accumulator is first updated
-- by the step @f@. The result is then passed into the sub-structures as
-- directed by @shape@.
strictFold :: (Functor f) => StrictFoldFun f a -> FoldShape f -> a -> (Fix f) -> a
strictFold f shape z = \structure -> cata step structure z
  where
    step layer !acc = shape layer (f layer acc)
{-# INLINE strictFold #-}

-- | Anamorphism: unfolds a seed into a 'Fix' structure. It applies the
-- coalgebra and then recursively unfolds every seed the coalgebra produced.
-- Inlining is delayed (NOINLINE[3]) in the same way as for 'cata'.
ana :: Functor f => (a -> f a) -> a -> Fix f
ana coalg = grow
  where
    grow seed = Fix (fmap grow (coalg seed))
{-# NOINLINE[3] ana #-}

-- | Hylomorphism: unfolds a seed with @g@ and folds the layers with @f@ as they
-- are produced, without ever wrapping them in a 'Fix'.
hylo :: Functor f => (f b -> b) -> (a -> f a) -> a -> b
hylo f g = h
  where
    h seed = f (fmap h (g seed))
{-# INLINE hylo #-}

-- Fuse an unfold that is immediately folded into a single 'hylo'. The
-- intermediate 'Fix' structure that 'ana' would build is then never allocated.
{-# RULES "hylo/cata_ana" forall f g x. cata f (ana g x) = hylo f g x #-}

-- | Applies a natural transformation to every layer of a 'Fix' structure. The
-- result is rebuilt as a 'Fix' of the target functor. It is defined via
-- 'cata', and inlining is delayed (NOINLINE[3]) so that calls to 'nt' remain
-- visible to RULES.
nt :: Functor f => (forall x . f x -> g x) -> Fix f -> Fix g
nt f = cata (\layer -> Fix (f layer))
{-# NOINLINE[3] nt #-}

-- A fold applied after a natural transformation becomes a single fold of the
-- original structure. The algebra is precomposed with the transformation,
-- so the intermediate structure built by 'nt' is skipped.
{-# RULES "nt/cata" forall f (g :: forall z . t z -> t' z) (x :: Fix t).  cata f (nt g x) = cata (f . g) x #-}

-- | The list @[0 .. n]@ as a 'Fix' structure. It is empty when @n < 0@. It is
-- built with 'ana', so a directly following 'cata' can be fused by the
-- @hylo/cata_ana@ rule.
--
-- The seed counts down the number of elements still to come (from @n@ to
-- @-1@) instead of counting the element value up. Counting up computes
-- @maxBound + 1@ when @n == maxBound@. That value wraps to @minBound@, never
-- exceeds @n@, and so the unfold would never stop.
enumTo :: Int -> Fix (ListF Int)
enumTo n = ana go n
  where
    -- @remaining@ is how many elements follow the current one. The current
    -- element is therefore @n - remaining@, which always lies in @[0, n]@.
    -- Neither subtraction can overflow.
    go remaining
      | remaining < 0 = Nil
      | otherwise = Cons (n - remaining) (remaining - 1)
{-# INLINE enumTo #-}

-- | Doubles every element, leaving the list structure unchanged. It is
-- expressed as a natural transformation via 'nt', so a following 'cata' can
-- be fused by the @nt/cata@ rule. @elt + elt@ can overflow 'Int' for large
-- elements.
double :: Fix (ListF Int) -> Fix (ListF Int)
double = nt twice
  where
    twice layer = case layer of
      Nil -> Nil
      Cons elt rest -> Cons (elt + elt) rest
{-# INLINE double #-}

-- | Arithmetic mean of the elements. It is computed in one strict pass that
-- accumulates the sum and the count together. An empty list gives
-- @0 / 0 :: Double@, i.e. NaN.
average :: Fix (ListF Int) -> Double
average xs =
  let (total, count) = strictFold (combine sum' length') listShape (0, 0) xs
   in fromIntegral total / fromIntegral count
{-# INLINE average #-}

-- | Mean of the doubled enumeration @[0, 2 .. 2 * n]@. For @n >= 0@ this is
-- @n@, barring 'Int' overflow in the doubling or the running sum. For
-- @n < 0@ the enumeration is empty and the result is NaN. The stages are
-- INLINE and are connected through 'ana', 'nt' and 'cata', which this
-- module's RULES target for fusion.
averageFirst2N :: Int -> Double
averageFirst2N n = average (double (enumTo n))