Re: the extension to continuation based contification

Top Pagina
Bijlagen:
Bericht als e-mail
+ (text/plain)
Delete this message
Reply to this message
Auteur: Stephen Weeks
Datum:  
Aan: MLton
Onderwerp: Re: the extension to continuation based contification

> Is it possible to only introduce the "return to X" continuations if the
> analysis hits top for the function X? Once it is determined that some
> top-level function (say F) has more than one continuation, there's no hope
> for contifying it; so, rather than propagating that non-contification
> through to functions whose continuations are F, establish a "return to F"
> continuation for the analysis of the body of F, maybe allowing other
> functions to be contified into F.


This makes sense. I don't quite see how to get a least-fixed-point computation
out of the idea though. I'm thinking about it.

> In any event, I'd be interested in looking at the implementation of the
> continuation contification pass for MLton, at least to see if my quick
> sketch of translating Reppy's analysis to the CPS IL is close.


Here's the current incarnation. It's a bit messy for two reasons. First, it handles all the
strategies (call based, continuation based, and both) so that I could do that benchmarking.
Second, I've started adding the "return to F" support, which isn't operational yet.

--------------------------------------------------------------------------------

(* Copyright (C) 1997-1999 NEC Research Institute.
 * Please see the file LICENSE for license information.
 *)
(*
 * This pass combines two analyses and transformations.
 *
 * The first analysis, "call based", notices whether a toplevel function
 *   1. is called from exactly one place outside itself, and
 *   2. calls itself only in tail position.
 * If both (1) and (2) hold, then that function can be made local at the
 * point of its single outer call.
 *
 * The second analysis is based on the following paper.
 *   Local CPS conversion in a direct-style compiler.  John Reppy.
 * It determines the set of continuations that a function is called with.  If
 * that set is a singleton, then the function can be prefixed onto the
 * continuation it is called with.
 *)


functor Contify (S: CONTIFY_STRUCTS): CONTIFY =
struct

open S
open Dec Transfer

(* ContSet: a monotone lattice cell approximating the set of continuations a
 * function may be called with.  The lattice is
 *    Empty  <  MainCont, One j, Return f  <  All
 * where any two distinct middle elements join to All.
 * Each cell also records the cells that must be at least as large as it
 * (lessThan); raising a cell propagates its new value to all of them.
 *)
structure ContSet =
   struct
      datatype set =
         Empty
       | MainCont
       | One of Jump.t
       | Return of Func.t
       | All

      fun layoutSet s =
         let open Layout
         in case s of
               Empty => str "{}"
             | MainCont => str "MainCont"
             | One j => seq [str "{ ", Jump.layout j, str " }"]
             | Return f => seq [str "{ ", Func.layout f, str " }"]
             | All => str "All"
         end

      (* A cell: the current approximation plus the cells constrained to be
       * at least as large as this one.
       *)
      datatype t = T of {set: set ref,
                         lessThan: t list ref}

      fun set (T {set, ...}) = !set

      fun layout (T {set, ...}) = layoutSet (!set)

      fun new () = T {set = ref Empty,
                      lessThan = ref []}

      (* up (c, s) joins s into c's value.  If the value changes, the new value
       * is pushed into every cell in c's lessThan list.  A cell that already
       * holds the incoming value does nothing, so propagation terminates
       * even when lessThan edges form a cycle.
       *)
      fun up (T {set, lessThan}, s) =
         let
            fun doit s = (set := s; List.foreach (!lessThan, fn c => up (c, s)))
         in case (!set, s) of
               (_, Empty) => ()
             | (All, _) => ()
             | (Empty, k) => doit k
             | (MainCont, MainCont) => ()
             | (One j, One j') => if Jump.equals (j, j') then () else doit All
             | (Return f, Return f') => if Func.equals (f, f') then () else doit All
             | _ => doit All
         end

      val up = Trace.trace2 ("up", layout, layoutSet, Unit.layout) up

      (* c1 <= c2 records that c2 must contain c1, and immediately pushes c1's
       * current value into c2.  Later increases of c1 reach c2 through c1's
       * lessThan list.
       *)
      val op <= =
         fn (T {set, lessThan}, c) =>
         (List.push (lessThan, c)
          ; up (c, !set))

      fun addReturn (c, f) = up (c, Return f)
      fun addJump (c, j) = up (c, One j)
      fun addMain c = up (c, MainCont)
      fun makeTop c = up (c, All)
   end

(* CallInfo: how many calls to a function occur outside its own body,
 * saturating at NotCont.  A non-tail self call also forces NotCont.
 *)
structure CallInfo =
   struct
      datatype t =
         NoOuterCalls
       | OneOuterCall
       | NotCont

      val toString =
         fn NoOuterCalls => "NoOuterCalls"
          | OneOuterCall => "OneOuterCall"
          | NotCont => "NotCont"

      val layout = Layout.str o toString
   end

structure Graph = DirectedGraph
structure Node = Graph.Node

(* contify program returns program with functions turned into local
 * continuations, according to !Control.contifyStrategy:
 *   None  - return program unchanged
 *   Call  - call-based contification only
 *   Cont  - continuation-based contification only
 *   Both  - both
 *)
fun contify (program as Program.T {datatypes, globals, functions, main}) =
   let
      val strategy = !Control.contifyStrategy
   in case strategy of
      Control.None => program
    | _ =>
         let
            (* Per-function analysis state.
             *  callers          functions other than f that call f; one
             *                   entry per call visited, so it may repeat.
             *  callInfo         number of calls to f from outside f.
             *  canPrefix        f belongs to a prefix group whose scc-head
             *                   computation succeeded.
             *  contSet          approximation of f's continuations.
             *  isLocal          scratch flag used by sccHeads.
             *  nested           for an scc head f, the heads of the remaining
             *                   members of its component; declared inside f.
             *  node             scratch graph node used by sccs.
             *  possiblePrefixes functions whose contSet is Return f.
             *  prefixes         scc heads of possiblePrefixes, on success.
             *  replace          for a function turned into a continuation:
             *                   its formals, body, and new jump name.
             *)
            val {get = funcInfo:
                 Func.t -> {
                            callers: Func.t list ref,
                            callInfo: CallInfo.t ref,
                            canPrefix: bool ref,
                            contSet: ContSet.t,
                            isLocal: bool ref,
                            nested: Func.t list ref,
                            node: Graph.Node.t option ref,
                            possiblePrefixes: Func.t list ref,
                            prefixes: Func.t list option ref,
                            replace: {args: (Var.t * Type.t) list,
                                      body: Exp.t,
                                      jump: Jump.t} option ref
                            }} =
               Property.get (Func.plist,
                             fn _ => {callers = ref [],
                                      callInfo = ref CallInfo.NoOuterCalls,
                                      canPrefix = ref false,
                                      contSet = ContSet.new (),
                                      isLocal = ref false,
                                      nested = ref [],
                                      node = ref NONE,
                                      possiblePrefixes = ref [],
                                      prefixes = ref NONE,
                                      replace = ref NONE})
            (* Compute the contSet and callInfo for each function.
             * The contSet is an over-approximation to the set of continuations
             * with which the function is called.
             * The callInfo tells at how many places outside of itself each
             * function is called.
             *)
            val _ = ContSet.addMain (#contSet (funcInfo main))
            val _ =
               List.foreach
               (functions, fn {name, body, ...} =>
                let val {callInfo, contSet = c, ...} = funcInfo name
                in Exp.foreachCall
                   (body, fn {func, cont, ...} =>
                    if Func.equals (name, func)
                       then
                          (* A self call.  A tail self call changes nothing.
                           * A non-tail one calls name with a continuation
                           * inside name itself, which this pass does not
                           * handle, so give up on name.
                           *)
                          (case cont of
                              NONE => ()
                            | SOME _ => (callInfo := CallInfo.NotCont
                                         ; ContSet.makeTop c))
                    else
                       let
                          val {callers, callInfo = callInfo', contSet = c', ...} =
                             funcInfo func
                          val _ =
                             let
                                datatype z = datatype CallInfo.t
                             in case !callInfo' of
                                   NoOuterCalls => callInfo' := OneOuterCall
                                 | OneOuterCall => callInfo' := NotCont
                                 | NotCont => ()
                             end
                          val _ = List.push (callers, name)
                       in case cont of
                             (* Tail call: func returns wherever name returns. *)
                             NONE => ContSet.<= (c, c')
                             (* ContSet.addReturn (c, name) *)
                           | SOME j => ContSet.addJump (c', j)
                       end)
                end)
            (* Record for each jump the functions that might be turned into
             * continuations as its prefixes.  These are the functions whose
             * contSet is exactly One of that jump.
             * Record for each function the functions that might prefix its
             * return, i.e. those whose contSet is Return of it.
             * Every nonempty group is entered into todo exactly once.
             *)
            val {get = jumpInfo: Jump.t ->
                 {possiblePrefixes: Func.t list ref,
                  prefixes: Func.t list option ref}} =
               Property.get (Jump.plist, fn _ => {possiblePrefixes = ref [],
                                                  prefixes = ref NONE})
            val todo = ref []
            val _ =
               List.foreach
               (functions, fn {name, ...} =>
                let
                   val {contSet, ...} = funcInfo name
                   fun doit (possiblePrefixes, prefixes) =
                      let
                         (* First member of this group: remember the group. *)
                         val _ =
                            case !possiblePrefixes of
                               [] => List.push (todo,
                                                (possiblePrefixes, prefixes))
                             | _ => ()
                      in List.push (possiblePrefixes, name)
                      end
                in case ContSet.set contSet of
                      ContSet.One j =>
                         let val {possiblePrefixes, prefixes, ...} = jumpInfo j
                         in doit (possiblePrefixes, prefixes)
                         end
                    | ContSet.Return f =>
                         let val {possiblePrefixes, prefixes, ...} = funcInfo f
                         in doit (possiblePrefixes, prefixes)
                         end
                    | _ => ()
                end)
            (* Strongly connected components of a group of functions.  Only
             * call edges between members of the group are used.  Returns one
             * list of functions per component.
             *)
            fun sccs (fs: Func.t list) =
               let
                  val g = Graph.new ()
                  val {get = nodeFunc, set = setNodeFunc} =
                     Property.getSetOnce
                     (Node.plist, Property.initRaise ("func", Node.layout))
                  (* Give every function in the group a node. *)
                  val _ =
                     List.foreach (fs, fn f =>
                                   let val {node, ...} = funcInfo f
                                      val n = Graph.newNode g
                                      val _ = setNodeFunc (n, f)
                                      val _ = node := SOME n
                                   in ()
                                   end)
                  (* Build the call graph.
                   * Edges go from nodes to the callers.
                   * Callers outside the group have no node and are skipped.
                   *)
                  val _ =
                     List.foreach
                     (fs, fn f =>
                      let val {node, callers, ...} = funcInfo f
                         val from = valOf (!node)
                      in List.foreach
                         (!callers, fn f' =>
                          let val {node, ...} = funcInfo f'
                          in case !node of
                                NONE => ()
                              | SOME to =>
                                   (Graph.addEdge (g, {from = from, to = to})
                                    ; ())
                          end)
                      end)
                  (* Reset the scratch node field so the next call starts clean. *)
                  val _ = List.foreach (fs, fn f => #node (funcInfo f) := NONE)
                  val nss = Graph.stronglyConnectedComponents g
               in List.map (nss, fn ns => List.revMap (ns, nodeFunc))
               end
            (* For each collection of functions that are going to prefix a cont,
             * do a strongly connected components computation.  In order for the
             * functions to be contified, there can be at most one function in
             * each component that is called from outside.  This is due to the
             * fact that mutually recursive continuations cannot be directly
             * declared in CPS -- the only way to do so is to nest one within the
             * other.  Thus only one can be "exported".
             *)
            exception Nope
            (* sccHeads returns the head function for each scc in the list of
             * functions: the unique member called from outside its scc.  The
             * other members are handled recursively in the same way, and the
             * result is stored in the head's nested field.
             * Raises Nope if an scc has more than one member called from
             * outside it.  Raises Error.bug if an scc has no such member.
             *)
            fun sccHeads (fs: Func.t list): Func.t list =
               List.map
               (sccs fs, fn fs =>
                let
                   fun setLocal b =
                      List.foreach (fs, fn f => #isLocal (funcInfo f) := b)
                   val _ = setLocal true
                   val outsideCaller = ref NONE
                   val tooMany = ref false
                   val _ =
                      List.foreach
                      (fs, fn f =>
                       let val {callers, ...} = funcInfo f
                       in if List.exists (!callers, fn f' =>
                                          not (!(#isLocal (funcInfo f'))))
                             then (case !outsideCaller of
                                      SOME _ => tooMany := true
                                    | NONE => outsideCaller := SOME f)
                          else ()
                       end)
                   (* Clear the isLocal marks BEFORE possibly raising Nope.
                    * Otherwise this component's functions stay marked local,
                    * and later sccHeads calls would wrongly treat calls from
                    * them as calls from inside.
                    *)
                   val _ = setLocal false
                   val _ = if !tooMany then raise Nope else ()
                in case !outsideCaller of
                      NONE => Error.bug "no outside caller"
                    | SOME f =>
                         let
                            val {nested, ...} = funcInfo f
                            val rest = List.removeFirst (fs, fn f' =>
                                                         Func.equals (f, f'))
                            val _ = nested := sccHeads rest
                         in f
                         end
                end)
            (* For each prefix group, try to compute its scc heads.  On success,
             * record them as the group's prefixes and mark every member
             * canPrefix.  On Nope, leave the group untouched.
             *)
            val _ =
               List.foreach
               (!todo, fn (possiblePrefixes, prefixes) =>
                let
                   val fs = !possiblePrefixes
                   val _ = prefixes := SOME (sccHeads fs)
                   val _ = List.foreach (fs, fn f =>
                                         #canPrefix (funcInfo f) := true)
                in ()
                end handle Nope => ())
            (* Diagnostics (disabled): dump the program and count how the
             * call-based and cont-based analyses compare on each function.
             *)
            val _ =
               if false
                  then
                     let
                        val _ =
                           Program.layouts
                           (program, fn l =>
                            (Layout.output (l, Out.error)
                             ; Out.newline Out.error))
                        val old = ref 0
                        val new = ref 0
                        val newNo = ref 0
                        val same = ref 0
                        val sameNo = ref 0
                     in List.foreach
                        (functions, fn {name, ...} =>
                         let
                            val {callInfo, canPrefix, contSet, ...} = funcInfo name
                            fun doit (r, s) =
                               (Int.inc r
                                ; if false
                                     then print (concat [s, " ",
                                                         Func.toString name, "\n"])
                                  else ())
                            datatype z = datatype CallInfo.t
                            datatype z = datatype ContSet.set
                         in case (!callInfo, ContSet.set contSet) of
                               (OneOuterCall, One _) =>
                                  if !canPrefix
                                     then doit (same, "same")
                                  else doit (sameNo, "sameNo")
                             | (OneOuterCall, _) => doit (old, "old")
                             | (_, One _) =>
                                  if !canPrefix
                                     then doit (new, "new")
                                  else doit (newNo, "newNo")
                             | _ => ()
                         end)
                        ; print (concat
                                 ["num functions ",
                                  Int.toString (List.length functions),
                                  "  same ", Int.toString (!same),
                                  "  sameNo ", Int.toString (!sameNo),
                                  "  num new ", Int.toString (!new),
                                  "  num newNo ", Int.toString (!newNo),
                                  "  num old ", Int.toString (!old),
                                  "\n"])
                     end
               else ()
            (* For functions turned into continuations, record their
             * args, body, and new name.  Which functions qualify depends on
             * the strategy.
             *)
            val _ =
               List.foreach
               (functions, fn {name, args, body, ...} =>
                let val {callInfo, canPrefix, replace, ...} = funcInfo name
                   val oneCall =
                      case !callInfo of
                         CallInfo.OneOuterCall => true
                       | _ => false
                   val oneCont = !canPrefix
                in if (case strategy of
                          Control.Both => oneCall orelse oneCont
                        | Control.Call => oneCall
                        | Control.Cont => oneCont
                        | Control.None => false)
                      then
                         replace :=
                         SOME {args = args,
                               body = body,
                               jump = Jump.newString (Func.originalName name)}
                   else ()
                end)
            (* Walk over all functions, removing those that aren't toplevel, and
             * descending those that are, inserting local functions
             * where necessary.
             *  - turn tail calls into nontail calls
             *  - turn returns into jumps
             * walkExp (f, e, c) rewrites expression e occurring in function f.
             * c is the continuation that returns and tail calls in e must now
             * use, or NONE if e still belongs to a toplevel function.
             *)
            fun walkExp (f: Func.t, e: Exp.t, c: Jump.t option): Exp.t =
               let
                  val {decs, transfer} = Exp.dest e
                  val decs =
                     List.fold
                     (rev decs, [], fn (d, ds) =>
                      case d of
                         Bind _ => d :: ds
                       | Fun {name, args, body} =>
                            (* A local jump.  Under cont-based strategies,
                             * declare right after it the functions contified
                             * into it, whose returns become jumps to it.
                             *)
                            Fun {name = name,
                                 args = args,
                                 body = walkExp (f, body, c)}
                            :: (if (case strategy of
                                       Control.Both => true
                                     | Control.Cont => true
                                     | _ => false)
                                   then
                                      let val {prefixes, ...} = jumpInfo name
                                      in case !prefixes of
                                            NONE => ds
                                          | SOME fs => nest (fs, SOME name, ds)
                                      end
                                else ds)
                       | HandlerPush h => HandlerPush h :: ds
                       | HandlerPop => HandlerPop :: ds)
                  fun make transfer = Exp.make {decs = decs,
                                                transfer = transfer}
               in
                  case transfer of
                     Call {func, args, cont} =>
                        let
                           (* A tail call inside a contified body must now
                            * continue to c.
                            *)
                           val newCont: Jump.t option =
                              case cont of
                                 NONE => c
                               | SOME _ => cont
                           val {callInfo, canPrefix, replace, ...} =
                              funcInfo func
                        in
                           case !replace of
                              NONE =>
                                 make (Call {func = func,
                                             args = args,
                                             cont = newCont})
                            | SOME {jump, args = formals, body} =>
                                 let
                                    (* A call-based contification is declared
                                     * here, at its single outer call site.
                                     * Otherwise the replacement is expected to
                                     * be declared by nest.
                                     *)
                                    val decs =
                                       if !callInfo = CallInfo.OneOuterCall
                                          andalso not (Func.equals (f, func))
                                          andalso (case strategy of
                                                      Control.Both =>
                                                         not (!canPrefix)
                                                    | Control.Call => true
                                                    | _ => false)
                                          then
                                             decs
                                             @ [Fun {name = jump,
                                                     args = formals,
                                                     body = walkExp (func, body,
                                                                     newCont)}]
                                       else decs
                                 in
                                    Exp.make
                                    {decs = decs,
                                     transfer = Jump {dst = jump, args = args}}
                                 end
                        end
                   | Return xs =>
                        make (case c of
                                 NONE => transfer
                               | SOME c => Jump {dst = c, args = xs})
                   | _ => make transfer
               end
            (* nest (fs, cont, ds) prepends to ds a local declaration for each
             * function in fs, all of which must have a replace entry.  Each
             * body is walked with continuation cont.  Inside each, nest
             * recursively declares the functions nested under it by sccHeads,
             * and any functions whose prefixes say they return to it.
             *)
            and nest (fs: Func.t list,
                      cont: Jump.t option,
                      ds: Dec.t list): Dec.t list =
               List.fold
               (rev fs, ds, fn (f, ds) =>
                let
                   val {replace, nested, prefixes, ...} = funcInfo f
                   (* Previously computed but ignored.  Using it matches how
                    * toplevel functions nest their prefixes.  prefixes is
                    * only SOME for functions once Return continuations are
                    * produced, which addReturn being disabled prevents today.
                    *)
                   val inner =
                      case !prefixes of
                         NONE => !nested
                       | SOME ps => List.appendRev (ps, !nested)
                   val {jump, args, body} = valOf (!replace)
                   val {decs, transfer} = Exp.dest (walkExp (f, body, cont))
                   val body = Exp.make {decs = nest (inner, cont, decs),
                                        transfer = transfer}
                in Fun {name = jump, args = args, body = body} :: ds
                end)
            val shrinkExp = shrinkExp globals
            (* Rebuild the toplevel function list.  Functions with a replace
             * entry are dropped, since they now live inside other functions.
             * The remaining bodies are walked, shrunk, and given their
             * Return-based prefixes, if any.
             *)
            val functions =
               List.fold
               (functions, [], fn ({name, args, body, returns}, functions) =>
                let val {replace, prefixes, ...} = funcInfo name
                in case !replace of
                      NONE =>
                         let
                            val body = shrinkExp (walkExp (name, body, NONE))
                            val body =
                               case !prefixes of
                                  NONE => body
                                | SOME fs =>
                                     let val {decs, transfer} = Exp.dest body
                                     in Exp.make {decs = nest (fs, NONE, decs),
                                                  transfer = transfer}
                                     end
                         in {name = name, args = args, returns = returns,
                             body = body}
                         end :: functions
                    | _ => functions
                end)
            val program =
               Program.T {datatypes = datatypes,
                          globals = globals,
                          functions = functions,
                          main = main}
            (* NOTE(review): presumably clears the plist properties (funcInfo,
             * jumpInfo) attached above -- confirm against Program.clear.
             *)
            val _ = Program.clear program
         in
            program
         end
   end

end