{-# LANGUAGE TypeFamilies, FlexibleContexts, TypeOperators #-}

-- Generic removal of weight nodes from a weighted tree ('Wtree').

module RmWeight
where

import Data.Generics.IG.Representable
import Data.Generics.IG.Tree

-- | Types from which weight nodes can be removed.
class RmWeight t where
  -- | Remove the weight nodes from a value, returning a value of the same
  -- type.
  rmWeight :: t -> t

-- | 'Unit' values are returned as-is: there is nothing inside to strip.
instance RmWeight Unit where
  rmWeight u = u

-- | 'Int' values contain no weight nodes and are returned as-is.
instance RmWeight Int where
  rmWeight n = n

-- | Products: strip weights from both components independently and
-- rebuild the pair.
instance (RmWeight a, RmWeight b) => RmWeight (a :*: b) where
  rmWeight prod = case prod of
    l :*: r -> rmWeight l :*: rmWeight r

-- | Sums: strip weights inside whichever alternative is present, keeping
-- the same constructor tag ('Inl' or 'Inr').
instance (RmWeight a, RmWeight b) => RmWeight (a :+: b) where
  rmWeight s = case s of
    Inl l -> Inl (rmWeight l)
    Inr r -> Inr (rmWeight r)

-- | Generic default for 'rmWeight' on any 'Representable' type.
--
-- It does three things in order:
--
-- 1. Converts the value to its generic representation with 'toRepr'.
-- 2. Applies the representation-level 'rmWeight' instances.
-- 3. Converts the result back with 'fromRepr'.
dft_rmWeight :: (Representable a, RmWeight (Repr a)) => a -> a
dft_rmWeight value = fromRepr (rmWeight (toRepr value))

-- | Weighted trees. This is the one place where 'rmWeight' is not purely
-- generic:
--
-- * A @Weight t w@ node is replaced by its subtree @t@. The weight @w@ is
--   discarded, and @t@ is itself cleaned recursively.
-- * Every other constructor is handled by the generic 'dft_rmWeight'. Its
--   traversal comes back to this instance for any nested subtrees.
instance (RmWeight a, RmWeight w) => RmWeight (Wtree a w) where
  rmWeight tree = case tree of
    Weight sub _ -> rmWeight sub
    _            -> dft_rmWeight tree

