Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JuvixReg transformation: initialize variables assigned in other branches #2650

Merged
merged 4 commits into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/Juvix/Compiler/Reg/Data/TransformationId.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ import Juvix.Prelude

data TransformationId
= Identity
| SSA
| Cleanup
| SSA
| InitBranchVars
deriving stock (Data, Bounded, Enum, Show)

data PipelineId
Expand All @@ -21,14 +22,15 @@ toCTransformations :: [TransformationId]
toCTransformations = [Cleanup]

toCairoTransformations :: [TransformationId]
toCairoTransformations = [Cleanup, SSA]
toCairoTransformations = [Cleanup, SSA, InitBranchVars]

instance TransformationId' TransformationId where
transformationText :: TransformationId -> Text
transformationText = \case
Identity -> strIdentity
SSA -> strSSA
Cleanup -> strCleanup
SSA -> strSSA
InitBranchVars -> strInitBranchVars

instance PipelineId' TransformationId PipelineId where
pipelineText :: PipelineId -> Text
Expand Down
7 changes: 5 additions & 2 deletions src/Juvix/Compiler/Reg/Data/TransformationId/Strings.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@ strCairoPipeline = "pipeline-cairo"
strIdentity :: Text
strIdentity = "identity"

strCleanup :: Text
strCleanup = "cleanup"

strSSA :: Text
strSSA = "ssa"

strCleanup :: Text
strCleanup = "cleanup"
strInitBranchVars :: Text
strInitBranchVars = "init-branch-vars"
34 changes: 33 additions & 1 deletion src/Juvix/Compiler/Reg/Extra/Recursors.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,17 @@ data ForwardRecursorSig m c = ForwardRecursorSig
}

data BackwardRecursorSig m a = BackwardRecursorSig
{ _backwardFun :: Code -> a -> [a] -> m (a, Code),
{ -- | In `_backwardFun is a as`: `is = i : is'` is the instruction list
-- currently being processed (the head `i` is the processed instruction, the
-- tail `is'` contains the instructions after it); `a` is the accumulator
-- for `is'`; `as` contains the accumulator values for the branches (for
-- `Branch` and `Case` instructions, otherwise empty). For the `Case`
-- instruction, the accumulator for the default branch (if present) is the
-- last element of `as`.
_backwardFun :: Code -> a -> [a] -> m (a, Code),
-- | `backwardAdjust a` adjusts the accumulator value when going backwards
-- into a branch. See also `FoldSig` in `Asm.Extra.Recursors` for more
-- explanations.
_backwardAdjust :: a -> a
}

Expand Down Expand Up @@ -125,3 +135,25 @@ ifoldFM f a0 is0 =

ifoldF :: (Monoid a) => (a -> Instruction -> a) -> a -> Code -> a
ifoldF f a is = runIdentity (ifoldFM (\a' -> return . f a') a is)

ifoldBM :: forall a m. (Monad m) => (a -> [a] -> Instruction -> m a) -> a -> Code -> m a
ifoldBM f a0 is0 =
fst
<$> recurseB
BackwardRecursorSig
{ _backwardFun = go,
_backwardAdjust = id
}
a0
is0
where
go :: Code -> a -> [a] -> m (a, Code)
go is a as = case is of
i : _ -> do
a' <- f a as i
return (a', is)
[] ->
return (a, is)

ifoldB :: (a -> [a] -> Instruction -> a) -> a -> Code -> a
ifoldB f a is = runIdentity (ifoldBM (\a' as' -> return . f a' as') a is)
4 changes: 3 additions & 1 deletion src/Juvix/Compiler/Reg/Transformation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import Juvix.Compiler.Reg.Data.TransformationId
import Juvix.Compiler.Reg.Transformation.Base
import Juvix.Compiler.Reg.Transformation.Cleanup
import Juvix.Compiler.Reg.Transformation.Identity
import Juvix.Compiler.Reg.Transformation.InitBranchVars
import Juvix.Compiler.Reg.Transformation.SSA

applyTransformations :: forall r. [TransformationId] -> InfoTable -> Sem r InfoTable
Expand All @@ -17,5 +18,6 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts
appTrans :: TransformationId -> InfoTable -> Sem r InfoTable
appTrans = \case
Identity -> return . identity
SSA -> return . computeSSA
Cleanup -> return . cleanup
SSA -> return . computeSSA
InitBranchVars -> return . initBranchVars
91 changes: 91 additions & 0 deletions src/Juvix/Compiler/Reg/Transformation/InitBranchVars.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
module Juvix.Compiler.Reg.Transformation.InitBranchVars where

import Data.Functor.Identity
import Data.HashSet qualified as HashSet
import Data.List qualified as List
import Juvix.Compiler.Reg.Extra
import Juvix.Compiler.Reg.Transformation.Base

-- | Inserts assignments to initialize variables assigned in other branches.
-- Assumes the input is in SSA form (which is preserved).
initBranchVars :: InfoTable -> InfoTable
initBranchVars = mapT (const goFun)
where
goFun :: Code -> Code
goFun =
snd
. runIdentity
. recurseB
BackwardRecursorSig
{ _backwardFun = \is a as -> return (go is a as),
_backwardAdjust = const mempty
}
mempty

go :: Code -> HashSet VarRef -> [HashSet VarRef] -> (HashSet VarRef, Code)
go is a as = case is of
Branch InstrBranch {..} : is' -> case as of
[a1, a2] -> (a <> a', i' : is')
where
a' = a1 <> a2
a1' = HashSet.difference a' a1
a2' = HashSet.difference a' a2
i' =
Branch
InstrBranch
{ _instrBranchTrue = addInits a1' _instrBranchTrue,
_instrBranchFalse = addInits a2' _instrBranchFalse,
..
}
_ -> impossible
Case InstrCase {..} : is' ->
(a <> a', i' : is')
where
a' = mconcat as
as' = map (HashSet.difference a') as
n = length _instrCaseBranches
brs' = zipWithExact goBranch (take n as') _instrCaseBranches
def' = maybe Nothing (Just . addInits (List.last as')) _instrCaseDefault
i' =
Case
InstrCase
{ _instrCaseBranches = brs',
_instrCaseDefault = def',
..
}

goBranch :: HashSet VarRef -> CaseBranch -> CaseBranch
goBranch vars = over caseBranchCode (addInits vars)
i : _ ->
case getResultVar i of
Just v ->
(HashSet.insert v a <> mconcat as, is)
Nothing ->
(a <> mconcat as, is)
[] ->
(a <> mconcat as, is)

addInits :: HashSet VarRef -> Code -> Code
addInits vars is = map mk (toList vars) ++ is
where
mk :: VarRef -> Instruction
mk vref =
Assign
InstrAssign
{ _instrAssignResult = vref,
_instrAssignValue = Const ConstVoid
}

checkInitialized :: InfoTable -> Bool
checkInitialized tab = all (goFun . (^. functionCode)) (tab ^. infoFunctions)
where
goFun :: Code -> Bool
goFun = snd . ifoldB go (mempty, True)
where
go :: (HashSet VarRef, Bool) -> [(HashSet VarRef, Bool)] -> Instruction -> (HashSet VarRef, Bool)
go (v, b) ls i = case getResultVar i of
Just vref -> (HashSet.insert vref v', b')
Nothing -> (v', b')
where
v' = v <> mconcat (map fst ls)
b' = b && allEqual (map fst ls) && and (map snd ls)
5 changes: 5 additions & 0 deletions src/Juvix/Prelude/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,11 @@ zip4Exact [] [] [] [] = []
zip4Exact (x1 : t1) (x2 : t2) (x3 : t3) (x4 : t4) = (x1, x2, x3, x4) : zip4Exact t1 t2 t3 t4
zip4Exact _ _ _ _ = error "zip4Exact"

allEqual :: (Eq a) => [a] -> Bool
lukaszcz marked this conversation as resolved.
Show resolved Hide resolved
allEqual = \case
a : as -> all (== a) as
[] -> True

--------------------------------------------------------------------------------
-- NonEmpty
--------------------------------------------------------------------------------
Expand Down
4 changes: 3 additions & 1 deletion test/Reg/Transformation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ module Reg.Transformation where

import Base
import Reg.Transformation.Identity qualified as Identity
import Reg.Transformation.InitBranchVars qualified as InitBranchVars
import Reg.Transformation.SSA qualified as SSA

allTests :: TestTree
allTests =
testGroup
"JuvixReg transformations"
[ Identity.allTests,
SSA.allTests
SSA.allTests,
InitBranchVars.allTests
]
25 changes: 25 additions & 0 deletions test/Reg/Transformation/InitBranchVars.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
module Reg.Transformation.InitBranchVars where

import Base
import Juvix.Compiler.Reg.Transformation
import Juvix.Compiler.Reg.Transformation.InitBranchVars
import Juvix.Compiler.Reg.Transformation.SSA
import Reg.Parse.Positive qualified as Parse
import Reg.Transformation.Base

allTests :: TestTree
allTests = testGroup "InitBranchVars" (map liftTest Parse.tests)

pipe :: [TransformationId]
pipe = [SSA, InitBranchVars]

liftTest :: Parse.PosTest -> TestTree
liftTest _testRun =
fromTest
Test
{ _testTransformations = pipe,
_testAssertion = \tab -> do
unless (checkSSA tab) $ error "check ssa"
unless (checkInitialized tab) $ error "check initialized",
_testRun
}
Loading