-- Copyright (C) 2018-2024 Jun Zhang <zhangjunphy[at]gmail[dot]com>
--
-- This file is a part of decafc.
--
-- decafc is free software: you can redistribute it and/or modify it under the
-- terms of the MIT (X11) License as described in the LICENSE file.
--
-- decafc is distributed in the hope that it will be useful, but WITHOUT ANY
-- WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
-- FOR A PARTICULAR PURPOSE.  See the X11 license for more details.

module CFG.Optimizations.RemoveNoOp where

import CFG.Optimizations.Optimizer
import CFG.Types
import Control.Lens (use, uses, view, (%=), (%~), (&), (+=), (.=), (.~), (^.), _1, _2, _3)
import Control.Monad.Except
import Data.List (find)
import Data.Map.Strict qualified as Map
import SSA (SSA)
import SSA qualified
import Types
import Util.Graph qualified as G

-- | Repeatedly remove no-op blocks from the CFG until none remain.
--
-- Each round re-reads the current CFG and asks 'findNoOpNode' for a
-- candidate. If one is found, it is spliced out with 'removeNodeAndPatchPhi'
-- and the pass recurses; when 'findNoOpNode' yields 'Nothing' the pass ends.
removeNoOp :: CFGOptimizer ()
removeNoOp =
  getCFG >>= maybe (return ()) removeAndContinue . findNoOpNode
  where
    -- remove one block, then look for the next candidate in the updated CFG
    removeAndContinue :: BBID -> CFGOptimizer ()
    removeAndContinue victim = removeNodeAndPatchPhi victim >> removeNoOp

-- | Find the first removable no-op basic block in the CFG, if any.
--
-- Candidates are scanned in 'Map.toList' order of the node map (ascending
-- 'BBID'). A block is removable when all of the following hold:
--
--   * it is neither the entry nor the exit block;
--   * it has no statements (so, in particular, no phi nodes);
--   * it has exactly one inbound and exactly one outbound edge;
--   * its outbound edge is a 'SeqEdge';
--   * its predecessor does not already have an edge to its successor.
--
-- The last condition is required because removal splices the predecessor
-- directly onto the successor. Edges are keyed by (source, destination), so
-- if such an edge already existed (e.g. the empty "then" arm of an @if@
-- without @else@), the splice would overwrite it, and phi nodes in the
-- successor would end up with two operands naming the same predecessor.
findNoOpNode :: CFG -> Maybe BBID
findNoOpNode (CFG g@(G.Graph nodes _) entry exit _ _) =
  fst <$> find (uncurry removable) (Map.toList nodes)
  where
    removable :: BBID -> BasicBlock -> Bool
    removable bbid node
      | bbid == entry || bbid == exit = False
      | not (null (node ^. #statements)) = False
      | otherwise =
          -- total match: exactly one in-edge and exactly one sequential out-edge
          case (G.inBound bbid g, G.outBound bbid g) of
            ([(bbidIn, _, _)], [(_, bbidOut, SeqEdge)]) ->
              not (hasEdge bbidIn bbidOut)
            _ -> False

    -- whether the graph already contains an edge src -> dst
    hasEdge :: BBID -> BBID -> Bool
    hasEdge src dst = any (\(_, dst', _) -> dst' == dst) (G.outBound src g)

-- | Splice the no-op block @bbid@ out of the CFG.
--
-- Precondition (established by 'findNoOpNode'): @bbid@ has exactly one
-- inbound and exactly one outbound edge. The pass then:
--
--   1. replaces @pred -> bbid -> succ@ with a direct edge @pred -> succ@ that
--      carries the label of the former inbound edge;
--   2. rewrites phi nodes in the successor so that operands which arrived
--      from @bbid@ now arrive from the predecessor;
--   3. deletes @bbid@.
--
-- Violating the precondition is a programming error. It aborts with a
-- descriptive message instead of silently picking an arbitrary edge.
removeNodeAndPatchPhi :: BBID -> CFGOptimizer ()
removeNodeAndPatchPhi bbid = do
  (CFG g _ _ _ _) <- getCFG
  case (G.inBound bbid g, G.outBound bbid g) of
    ([(bbidIn, _, edgeIn)], [(_, bbidOut, _)]) -> do
      -- update the destination's inbound edge: bypass the removed block
      updateCFG $ do
        G.deleteEdge bbidIn bbid
        G.deleteEdge bbid bbidOut
        G.addEdge bbidIn bbidOut edgeIn
      -- patch phi nodes in the successor, then drop the block itself
      updateCFG $ do
        G.adjustNode bbidOut (patchBasicBlock bbidIn)
        G.deleteNode bbid
    _ ->
      error $
        "removeNodeAndPatchPhi: block "
          <> show bbid
          <> " must have exactly one inbound and one outbound edge"
  where
    -- Redirect phi operands that came from the removed block to bbidIn;
    -- all other statements are returned unchanged.
    patchPhi :: BBID -> SSA -> SSA
    patchPhi bbidIn (SSA.Phi dst predecessors) =
      let replace (var, bbid') =
            if bbid' == bbid then (var, bbidIn) else (var, bbid')
       in SSA.Phi dst (replace <$> predecessors)
    patchPhi _ ssa = ssa

    -- Apply 'patchPhi' to every statement of a block.
    patchBasicBlock :: BBID -> BasicBlock -> BasicBlock
    patchBasicBlock bbidIn node = node & #statements %~ fmap (patchPhi bbidIn)