{-
    Kaya - My favourite toy language.
    Copyright (C) 2004, 2005 Edwin Brady

    This file is distributed under the terms of the GNU General
    Public Licence. See COPYING for licence.
-}

module LambdaLift(lambdalift) where

-- The lambda lifter: lifts lambdas (closures) out of function bodies and
-- turns them into named top-level functions.

import Language
import Debug.Trace

-- | Lambda-lift a whole program, one declaration at a time.
--
-- Each input declaration expands to itself followed by any functions
-- generated from its closures; the relative order of declarations is kept.
lambdalift :: Program -> Program
lambdalift prog = concatMap liftdecl prog

-- | Lambda-lift a single top-level declaration.
--
-- Only a 'FunBind' whose body is 'Defined' is rewritten: its body goes
-- through 'lift', and every definition generated along the way becomes an
-- extra 'FunBind' placed right after the original. Generated bindings reuse
-- the parent's @f@, @l@ and @ority@ fields, get the options
-- @[Inline, Generated]@ and an empty comment string.
-- Any other declaration is returned unchanged as a singleton list.
liftdecl :: Decl -> [Decl]
liftdecl (FunBind (f, l, nm, ty, ops, Defined def) comm ority) =
    FunBind (f, l, nm, ty, ops, Defined def') comm ority : map toTop lifted
  where
    (lifted, def') = lift nm ty def
    -- Turn one generated (name, type, body) triple into a declaration.
    toTop (n, fty, body) =
        FunBind (f, l, n, fty, [Inline, Generated], Defined body) "" ority
liftdecl other = [other]

-- | Lambda-lift every 'Closure' inside an expression.
--
-- Returns the list of newly generated definitions (most recent first, as
-- built by 'mkNewFun') together with the rewritten expression, in which
-- each 'Closure' has been replaced by a 'Partial' application of its
-- generated function.
--
-- @nm@ is the name of the enclosing function and is used as the base name
-- for generated functions. @ty@ is not consulted by the traversal.
--
-- The definitions accumulator is threaded strictly left to right through
-- every sub-expression. This matters: 'mkNewFun' derives each new name
-- from the number of definitions generated so far, so dropping or
-- reordering the accumulator yields missing functions and clashing names.
--
-- FIXME: Would be much neater with a Liftable type class then we could
-- automatically have lifting over lists of syntax etc.
lift :: Name -> Type -> Expr Name -> ([(Name, Type, Expr Name)],Expr Name)
lift nm ty exp = lift' [] [] exp
  where
    -- Threading helpers for the structural cases: lift the
    -- sub-expression(s) in order, passing the accumulated definitions
    -- along, then rebuild the node with the supplied constructor.
    liftOne defs locs con e =
        let (defs', e') = lift' defs locs e in
            (defs', con e')
    liftTwo defs locs con e1 e2 =
        let (defs', e1') = lift' defs locs e1
            (defs'', e2') = lift' defs' locs e2 in
            (defs'', con e1' e2')
    liftMany defs locs con es =
        let (defs', es') = lifts' defs locs es in
            (defs', con es')
    -- Assignments lift the right-hand side before the l-value.
    liftAssign defs locs con lval e =
        let (defs', e') = lift' defs locs e
            (defs'', lval') = liftlval defs' locs lval in
            (defs'', con lval' e')

    -- A nested function body: its own arguments replace the current locals.
    lift' defs _ (Lambda ivs args body) =
        liftOne defs args (Lambda ivs args) body
    -- Binders cons their variable onto the front of the local environment.
    lift' defs locs (Bind n t e1 e2) =
        let (defs', e1') = lift' defs locs e1
            (defs'', e2') = lift' defs' ((n, t) : locs) e2 in
            (defs'', Bind n t e1' e2')
    lift' defs locs (Declare f l (n, loc) t body) =
        liftOne defs ((n, t) : locs) (Declare f l (n, loc) t) body
    lift' defs locs (Return e) = liftOne defs locs Return e
    lift' defs locs (Assign lval e) = liftAssign defs locs Assign lval e
    lift' defs locs (AssignOp op lval e) =
        liftAssign defs locs (AssignOp op) lval e
    lift' defs locs (AssignApp lval e) =
        liftAssign defs locs AssignApp lval e
    lift' defs locs (Seq e1 e2) = liftTwo defs locs Seq e1 e2
    lift' defs locs (Apply f es) =
        let (defs', f') = lift' defs locs f
            (defs'', es') = lifts' defs' locs es in
            (defs'', Apply f' es')
    lift' defs locs (Partial b f es i) =
        let (defs', f') = lift' defs locs f
            (defs'', es') = lifts' defs' locs es in
            (defs'', Partial b f' es' i)
    lift' defs locs (Foreign t n estys) =
        let (es, tys) = unzip estys
            (defs', es') = lifts' defs locs es in
            (defs', Foreign t n (zip es' tys))
    lift' defs locs (While e1 e2) = liftTwo defs locs While e1 e2
    lift' defs locs (DoWhile e1 e2) = liftTwo defs locs DoWhile e1 e2
    lift' defs locs (For i v j lval e1 e2) =
        let scope = fakevars ++ locs
            (defs', e1') = lift' defs scope e1
            (defs'', e2') = lift' defs' scope e2
            (defs''', lval') = liftlval defs'' locs lval in
            (defs''', For i v j lval' e1' e2')
      where
        -- fake variables to account for For introducing names into the
        -- environment
        fakevars = [(MN ("_",i), Prim Number), (MN ("_",j), UnknownType)]
    -- All four sub-expressions contribute definitions; the final
    -- accumulator (after e4) must be the one returned.
    lift' defs locs (TryCatch e1 e2 e3 e4) =
        let (defs1, e1') = lift' defs locs e1
            (defs2, e2') = lift' defs1 locs e2
            (defs3, e3') = lift' defs2 locs e3
            (defs4, e4') = lift' defs3 locs e4 in
            (defs4, TryCatch e1' e2' e3' e4')
    lift' defs locs (NewTryCatch e1 cs) =
        let (defs', e1') = lift' defs locs e1
            (defs'', cs') = liftCs defs' locs cs in
            (defs'', NewTryCatch e1' cs')
    lift' defs locs (Throw e) = liftOne defs locs Throw e
    lift' defs locs (Except e1 e2) = liftTwo defs locs Except e1 e2
    lift' defs locs (InferPrint e t s i) =
        liftOne defs locs (\e' -> InferPrint e' t s i) e
    lift' defs locs (PrintNum e) = liftOne defs locs PrintNum e
    lift' defs locs (PrintStr e) = liftOne defs locs PrintStr e
    lift' defs locs (PrintExc e) = liftOne defs locs PrintExc e
    lift' defs locs (Infix op e1 e2) = liftTwo defs locs (Infix op) e1 e2
    lift' defs locs (RealInfix op e1 e2) =
        liftTwo defs locs (RealInfix op) e1 e2
    lift' defs locs (InferInfix op e1 e2 t s i) =
        liftTwo defs locs (\a b -> InferInfix op a b t s i) e1 e2
    lift' defs locs (CmpExcept op e1 e2) =
        liftTwo defs locs (CmpExcept op) e1 e2
    lift' defs locs (CmpStr op e1 e2) = liftTwo defs locs (CmpStr op) e1 e2
    lift' defs locs (Append e1 e2) = liftTwo defs locs Append e1 e2
    lift' defs locs (AppendChain es) = liftMany defs locs AppendChain es
    lift' defs locs (Unary op e1) = liftOne defs locs (Unary op) e1
    lift' defs locs (RealUnary op e1) = liftOne defs locs (RealUnary op) e1
    lift' defs locs (InferUnary op e1 t s i) =
        liftOne defs locs (\e' -> InferUnary op e' t s i) e1
    lift' defs locs (Coerce t1 t2 e1) = liftOne defs locs (Coerce t1 t2) e1
    lift' defs locs (Case scrut alts) =
        let (defs', scrut') = lift' defs locs scrut
            (defs'', alts') = liftalts defs' locs alts in
            (defs'', Case scrut' alts')
    lift' defs locs (If e1 e2 e3) =
        let (defs1, e1') = lift' defs locs e1
            (defs2, e2') = lift' defs1 locs e2
            (defs3, e3') = lift' defs2 locs e3 in
            (defs3, If e1' e2' e3')
    lift' defs locs (Index e1 e2) = liftTwo defs locs Index e1 e2
    lift' defs locs (Field e1 n i t) =
        liftOne defs locs (\e' -> Field e' n i t) e1
    lift' defs locs (ArrayInit es) = liftMany defs locs ArrayInit es
    lift' defs locs (Annotation a e) = liftOne defs locs (Annotation a) e

    -- The interesting case.
    -- Make a new function of the closure body, add it to defs, and replace
    -- this occurrence with a partial application of that function to the
    -- locals so far. The enclosing locals become 'Var' parameters and the
    -- closure's own parameters become 'Copy' parameters.
    lift' defs locs (Closure tys rty body) =
        let locs' = renamelocs locs
            (defs', body') = lift' defs (locs' ++ tys) body
            -- locs is built by consing onto the front, hence the reverse.
            args = map ((,) Var) (reverse locs') ++ map ((,) Copy) tys
            (defs'', nm') = mkNewFun nm args rty body' defs'
            fnty = getty defs'' in
            -- NOTE(review): Loc 0..n-1 presumably names the first n local
            -- slots of the enclosing function -- confirm against codegen.
            (defs'', Partial True (Global nm' (mangling fnty) (argSpace fnty))
                                  (map Loc [0 .. length locs' - 1])
                                  (length tys))
      where
        -- Type of the definition mkNewFun just prepended.
        getty ((_, t, _):_) = t
        -- Give locals unique names to avoid clashes
        renamelocs xs = rl 0 xs
        rl _ [] = []
        -- skip machine generated names, they won't be used
        -- rl n (((MN _),_):xs) = rl n xs
        rl n ((v, t):xs) = (rln n v, t) : rl (n + 1) xs
        rln n (UN s) = MN (s, n)
        rln n (NS m v) = NS m (rln n v)
        rln _ m@(MN _) = m

    -- Non recursive catch all
    lift' defs _ e = (defs, e)

    -- L-values: the index expression is lifted before the inner l-value.
    liftlval defs locs (AIndex a e) =
        let (defs', e') = lift' defs locs e
            (defs'', a') = liftlval defs' locs a in
            (defs'', AIndex a' e')
    liftlval defs locs (AField a n i t) =
        let (defs', a') = liftlval defs locs a in
            (defs', AField a' n i t)
    liftlval defs _ x = (defs, x)

    -- Case alternatives, lifted in order.
    liftalts defs _ [] = (defs, [])
    liftalts defs locs ((Alt t i es e):xs) =
        let (defs', es') = lifts' defs locs es
            (defs'', e') = lift' defs' locs e
            (defs''', xs') = liftalts defs'' locs xs in
            (defs''', Alt t i es' e' : xs')
    liftalts defs locs ((ArrayAlt es e):xs) =
        let (defs', es') = lifts' defs locs es
            (defs'', e') = lift' defs' locs e
            (defs''', xs') = liftalts defs'' locs xs in
            (defs''', ArrayAlt es' e' : xs')
    liftalts defs locs ((ConstAlt pt c e):xs) =
        let (defs', e') = lift' defs locs e
            (defs'', xs') = liftalts defs' locs xs in
            (defs'', ConstAlt pt c e' : xs')
    liftalts defs locs ((Default e):xs) =
        let (defs', e') = lift' defs locs e
            (defs'', xs') = liftalts defs' locs xs in
            (defs'', Default e' : xs')

    -- Catch clauses of a NewTryCatch, lifted in order.
    -- (A Liftable class would give this for free via Liftable a => Liftable [a].)
    liftCs defs _ [] = (defs, [])
    liftCs defs locs (c:cs) =
        let (defs', c') = liftC defs locs c
            (defs'', cs') = liftCs defs' locs cs in
            (defs'', c' : cs')

    liftC defs locs (Catch (Left (n, es)) e) =
        let (defs', es') = lifts' defs locs es
            (defs'', e') = lift' defs' locs e in
            (defs'', Catch (Left (n, es')) e')
    liftC defs locs (Catch (Right e1) e2) =
        let (defs', e1') = lift' defs locs e1
            (defs'', e2') = lift' defs' locs e2 in
            (defs'', Catch (Right e1') e2')

    -- A list of expressions, lifted left to right.
    lifts' defs _ [] = (defs, [])
    lifts' defs locs (x:xs) =
        let (defs', x') = lift' defs locs x
            (defs'', xs') = lifts' defs' locs xs in
            (defs'', x' : xs')

-- | Build a new generated top-level function and prepend it to @defs@.
--
-- The parameters are taken from the (argument kind, (name, type)) pairs.
-- The body is wrapped in a 'LamBody' annotation naming @basename@. The new
-- name is derived from @basename@ and the number of definitions already in
-- @defs@, so callers must thread @defs@ faithfully to keep names unique.
-- Returns the extended definition list and the new function's name.
-- Only 'UN' base names (possibly inside 'NS') are supported; anything else
-- is an internal error.
mkNewFun :: Name ->
            [(ArgType, (Name,Type))] -> Type ->
            Expr Name -> [(Name, Type, Expr Name)]
            -> ([(Name, Type, Expr Name)], Name)
mkNewFun basename locs rty exp defs = ((fnname, fnty, def) : defs, fnname)
  where
    (kinds, params) = unzip locs
    fnty = Fn (map (const Nothing) locs) (map snd params) rty
    def = Lambda kinds params (Annotation (LamBody (showuser basename)) exp)
    fnname = freshName basename
    -- Suffix the user name with the current definition count.
    freshName (UN n) = MN (n, length defs)
    freshName (NS s n) = NS s (freshName n)
    freshName other = error $ "Can't happen (mkNewFun, "++showuser other++")"