From 2dfca50ab1bc173bc835a2973e676dd4e84225a3 Mon Sep 17 00:00:00 2001 From: Lukasz Czajka Date: Wed, 27 Nov 2024 19:25:43 +0100 Subject: [PATCH 01/12] loop hoisting wip --- src/Juvix/Compiler/Core/Extra/Base.hs | 6 ++ src/Juvix/Compiler/Core/Extra/Utils.hs | 14 +++++ src/Juvix/Compiler/Core/Info/FreeVarsInfo.hs | 11 ++-- .../Core/Transformation/ComputeTypeInfo.hs | 8 +++ .../Transformation/Optimize/LoopHoisting.hs | 59 +++++++++++++++++++ 5 files changed, 94 insertions(+), 4 deletions(-) create mode 100644 src/Juvix/Compiler/Core/Transformation/Optimize/LoopHoisting.hs diff --git a/src/Juvix/Compiler/Core/Extra/Base.hs b/src/Juvix/Compiler/Core/Extra/Base.hs index f5a1b6b775..5d2a798b81 100644 --- a/src/Juvix/Compiler/Core/Extra/Base.hs +++ b/src/Juvix/Compiler/Core/Extra/Base.hs @@ -59,12 +59,18 @@ mkLambda i bi b = NLam (Lambda i bi b) mkLambda' :: Type -> Node -> Node mkLambda' ty = mkLambda Info.empty (mkBinder' ty) +mkLambda'' :: Binder -> Node -> Node +mkLambda'' = mkLambda Info.empty + mkLambdas :: [Info] -> [Binder] -> Node -> Node mkLambdas is bs n = foldl' (flip (uncurry mkLambda)) n (reverse (zipExact is bs)) mkLambdas' :: [Type] -> Node -> Node mkLambdas' tys n = foldl' (flip mkLambda') n (reverse tys) +mkLambdas'' :: [Binder] -> Node -> Node +mkLambdas'' bs n = foldl' (flip mkLambda'') n (reverse bs) + mkLetItem :: Text -> Type -> Node -> LetItem mkLetItem name ty = LetItem (mkBinder name ty) diff --git a/src/Juvix/Compiler/Core/Extra/Utils.hs b/src/Juvix/Compiler/Core/Extra/Utils.hs index 77d5d836d0..349bbae432 100644 --- a/src/Juvix/Compiler/Core/Extra/Utils.hs +++ b/src/Juvix/Compiler/Core/Extra/Utils.hs @@ -162,11 +162,25 @@ isDataValue = \case NCtr Constr {..} -> all isDataValue _constrArgs _ -> False +isFullyApplied :: Module -> Node -> Bool +isFullyApplied md node = case h of + NIdt Ident {..} + | Just ii <- lookupIdentifierInfo' md _identSymbol -> + length args >= ii ^. identifierArgsNum + _ -> True + where + (h, args) = unfoldApps' node + isFailNode :: Node -> Bool isFailNode = \case NBlt (BuiltinApp {..}) | _builtinAppOp == OpFail -> True _ -> False +isLambda :: Node -> Bool +isLambda = \case + NLam {} -> True + _ -> False + isTrueConstr :: Node -> Bool isTrueConstr = \case NCtr Constr {..} | _constrTag == BuiltinTag TagTrue -> True diff --git a/src/Juvix/Compiler/Core/Info/FreeVarsInfo.hs b/src/Juvix/Compiler/Core/Info/FreeVarsInfo.hs index 312c6cfe61..07c3e13fcb 100644 --- a/src/Juvix/Compiler/Core/Info/FreeVarsInfo.hs +++ b/src/Juvix/Compiler/Core/Info/FreeVarsInfo.hs @@ -38,7 +38,7 @@ computeFreeVarsInfo' lambdaMultiplier = umap go fvi = FreeVarsInfo . fmap (* lambdaMultiplier) - $ getFreeVars 1 _lambdaBody + $ getFreeVars' 1 _lambdaBody _ -> modifyInfo (Info.insert fvi) node where @@ -47,13 +47,13 @@ computeFreeVarsInfo' lambdaMultiplier = umap go foldr ( \NodeChild {..} acc -> Map.unionWith (+) acc $ - getFreeVars _childBindersNum _childNode + getFreeVars' _childBindersNum _childNode ) mempty (children node) - getFreeVars :: Int -> Node -> Map Index Int - getFreeVars bindersNum node = + getFreeVars' :: Int -> Node -> Map Index Int + getFreeVars' bindersNum node = Map.mapKeysMonotonic (\idx -> idx - bindersNum) . Map.filterWithKey (\idx _ -> idx >= bindersNum) $ getFreeVarsInfo node ^. infoFreeVars @@ -61,6 +61,9 @@ computeFreeVarsInfo' lambdaMultiplier = umap go getFreeVarsInfo :: Node -> FreeVarsInfo getFreeVarsInfo = fromJust . Info.lookup kFreeVarsInfo . getInfo +getFreeVars :: Node -> [Index] +getFreeVars = Map.keys . Map.filter (> 0) . (^. infoFreeVars) . getFreeVarsInfo + freeVarOccurrences :: Index -> Node -> Int freeVarOccurrences idx n = fromMaybe 0 (Map.lookup idx (getFreeVarsInfo n ^. infoFreeVars)) diff --git a/src/Juvix/Compiler/Core/Transformation/ComputeTypeInfo.hs b/src/Juvix/Compiler/Core/Transformation/ComputeTypeInfo.hs index 4e41e5014f..e5972ccb50 100644 --- a/src/Juvix/Compiler/Core/Transformation/ComputeTypeInfo.hs +++ b/src/Juvix/Compiler/Core/Transformation/ComputeTypeInfo.hs @@ -5,6 +5,14 @@ import Juvix.Compiler.Core.Extra import Juvix.Compiler.Core.Info.TypeInfo qualified as Info import Juvix.Compiler.Core.Transformation.Base +computeNodeType' :: Module -> BinderList Binder -> Node -> Type +computeNodeType' md bl node = rePis argtys' ty' + where + ty = computeNodeType md (mkLambdas'' (reverse (toList bl)) node) + (argtys, ty') = unfoldPi ty + argtys' = drop (length bl) argtys + +-- | Computes the type of a closed well-typed node. computeNodeType :: Module -> Node -> Type computeNodeType md = Info.getNodeType . computeNodeTypeInfo md diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/LoopHoisting.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/LoopHoisting.hs new file mode 100644 index 0000000000..79c846000d --- /dev/null +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/LoopHoisting.hs @@ -0,0 +1,59 @@ +module Juvix.Compiler.Core.Transformation.Optimize.LoopHoisting (loopHoisting) where + +import Juvix.Compiler.Core.Extra +import Juvix.Compiler.Core.Info.FreeVarsInfo qualified as Info +import Juvix.Compiler.Core.Transformation.Base +import Juvix.Compiler.Core.Transformation.ComputeTypeInfo (computeNodeType') + +loopHoisting :: Module -> Module +loopHoisting md = mapT (const (umapL go)) md + where + go :: BinderList Binder -> Node -> Node + go bl node = case node of + NApp {} -> goApp bl h 0 args + where + (h, args) = unfoldApps node + _ -> + node + + goApp :: BinderList Binder -> Node -> Int -> [(Info, Node)] -> Node + goApp bl h argNum args = case args of + [] -> h + (info, arg) : args' -> case arg of + NLam {} -> + -- TODO: Check if `h` is recursive and its `argNum`th argument is + -- invariant + goApp bl (goLamApp bl info h arg) (argNum + 1) args' + _ -> goApp bl (mkApp info h arg) (argNum + 1) args' + + goLamApp :: BinderList Binder -> Info -> Node -> Node -> Node + goLamApp bl info h arg + | null subterms = mkApp info h arg + | otherwise = + mkLets' + (map (\node -> (computeNodeType' md bl node, node)) subterms) + (mkApp info h (reLambdasRev lams body')) + where + (lams, body) = unfoldLambdasRev arg + (subterms, body') = extractMaximalInvariantSubterms (length lams) body + + extractMaximalInvariantSubterms :: Int -> Node -> ([Node], Node) + extractMaximalInvariantSubterms bindersNum body = + first (map (removeInfo Info.kFreeVarsInfo)) + . second (removeInfo Info.kFreeVarsInfo) + . run + . runState [] + $ dmapNRM extract (Info.computeFreeVarsInfo body) + where + extract :: (Member (State [Node]) r) => Level -> Node -> Sem r Recur + extract n node + | not (isImmediate md node || isLambda node) + && isFullyApplied md node + && null fvars = do + k <- length <$> get @[Node] + modify' (node :) + return $ End (mkVar' (n + bindersNum + k)) + | otherwise = + return $ Recur node + where + fvars = filter (>= n + bindersNum) $ Info.getFreeVars node From d9d6a03cae508bce030d26220d11d2ad99d5fbf7 Mon Sep 17 00:00:00 2001 From: Lukasz Czajka Date: Thu, 28 Nov 2024 12:06:46 +0100 Subject: [PATCH 02/12] check recursively invariant args --- src/Juvix/Compiler/Core/Extra/Utils.hs | 39 +++++++++++++++++++ .../Transformation/Optimize/LoopHoisting.hs | 18 +++++---- .../Transformation/Optimize/SpecializeArgs.hs | 30 +------------- 3 files changed, 50 insertions(+), 37 deletions(-) diff --git a/src/Juvix/Compiler/Core/Extra/Utils.hs b/src/Juvix/Compiler/Core/Extra/Utils.hs index 349bbae432..8cf5a1ad5f 100644 --- a/src/Juvix/Compiler/Core/Extra/Utils.hs +++ b/src/Juvix/Compiler/Core/Extra/Utils.hs @@ -590,3 +590,42 @@ checkInfoTable tab = all isClosed (tab ^. identContext) && all (isClosed . (^. identifierType)) (tab ^. infoIdentifiers) && all (isClosed . (^. constructorType)) (tab ^. infoConstructors) + +-- | Checks if an argument is passed without modification to direct recursive calls. +isArgRecursiveInvariant :: Module -> Symbol -> Int -> Bool +isArgRecursiveInvariant tab sym argNum = run $ execState True $ dmapNRM go body + where + nodeSym = lookupIdentifierNode tab sym + (lams, body) = unfoldLambdas nodeSym + n = length lams + + go :: (Member (State Bool) r) => Level -> Node -> Sem r Recur + go lvl node = case node of + NApp {} -> + let (h, args) = unfoldApps' node + in case h of + NIdt Ident {..} + | _identSymbol == sym -> + let b = + argNum <= length args + && case args !! (argNum - 1) of + NVar Var {..} | _varIndex == lvl + n - argNum -> True + _ -> False + in do + modify' (&& b) + mapM_ (dmapNRM' (lvl, go)) args + return $ End node + _ -> return $ Recur node + NIdt Ident {..} + | _identSymbol == sym -> do + put False + return $ End node + _ -> return $ Recur node + +isDirectlyRecursive :: Module -> Symbol -> Bool +isDirectlyRecursive md sym = ufold (\x xs -> or (x : xs)) go (lookupIdentifierNode md sym) + where + go :: Node -> Bool + go = \case + NIdt Ident {..} -> _identSymbol == sym + _ -> False diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/LoopHoisting.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/LoopHoisting.hs index 79c846000d..9fb4875171 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/LoopHoisting.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/LoopHoisting.hs @@ -10,21 +10,23 @@ loopHoisting md = mapT (const (umapL go)) md where go :: BinderList Binder -> Node -> Node go bl node = case node of - NApp {} -> goApp bl h 0 args + NApp {} -> case h of + NIdt Ident {..} -> + goApp bl _identSymbol h 0 args + _ -> node where (h, args) = unfoldApps node _ -> node - goApp :: BinderList Binder -> Node -> Int -> [(Info, Node)] -> Node - goApp bl h argNum args = case args of + goApp :: BinderList Binder -> Symbol -> Node -> Int -> [(Info, Node)] -> Node + goApp bl sym h argNum args = case args of [] -> h (info, arg) : args' -> case arg of - NLam {} -> - -- TODO: Check if `h` is recursive and its `argNum`th argument is - -- invariant - goApp bl (goLamApp bl info h arg) (argNum + 1) args' - _ -> goApp bl (mkApp info h arg) (argNum + 1) args' + NLam {} + | isArgRecursiveInvariant md sym argNum && isDirectlyRecursive md sym -> + goApp bl sym (goLamApp bl info h arg) (argNum + 1) args' + _ -> goApp bl sym (mkApp info h arg) (argNum + 1) args' goLamApp :: BinderList Binder -> Info -> Node -> Node -> Node goLamApp bl info h arg diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/SpecializeArgs.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/SpecializeArgs.hs index 8f192cd754..d5b5ccd2a7 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/SpecializeArgs.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/SpecializeArgs.hs @@ -48,36 +48,8 @@ isMarkedSpecializable md = \case _ -> False --- | Checks if an argument is passed without modification to recursive calls. isArgSpecializable :: Module -> Symbol -> Int -> Bool -isArgSpecializable tab sym argNum = run $ execState True $ dmapNRM go body - where - nodeSym = lookupIdentifierNode tab sym - (lams, body) = unfoldLambdas nodeSym - n = length lams - - go :: (Member (State Bool) r) => Level -> Node -> Sem r Recur - go lvl node = case node of - NApp {} -> - let (h, args) = unfoldApps' node - in case h of - NIdt Ident {..} - | _identSymbol == sym -> - let b = - argNum <= length args - && case args !! (argNum - 1) of - NVar Var {..} | _varIndex == lvl + n - argNum -> True - _ -> False - in do - modify' (&& b) - mapM_ (dmapNRM' (lvl, go)) args - return $ End node - _ -> return $ Recur node - NIdt Ident {..} - | _identSymbol == sym -> do - put False - return $ End node - _ -> return $ Recur node +isArgSpecializable = isArgRecursiveInvariant convertNode :: forall r. (Member InfoTableBuilder r) => Node -> Sem r Node convertNode = dmapLRM go From a2ff501742a0b81b6f5d27fd5ee53d5d14213d0e Mon Sep 17 00:00:00 2001 From: Lukasz Czajka Date: Thu, 28 Nov 2024 12:45:23 +0100 Subject: [PATCH 03/12] transformation id --- src/Juvix/Compiler/Core/Data/TransformationId.hs | 2 ++ src/Juvix/Compiler/Core/Data/TransformationId/Strings.hs | 3 +++ src/Juvix/Compiler/Core/Transformation.hs | 2 ++ 3 files changed, 7 insertions(+) diff --git a/src/Juvix/Compiler/Core/Data/TransformationId.hs b/src/Juvix/Compiler/Core/Data/TransformationId.hs index a951585dae..49045c9aec 100644 --- a/src/Juvix/Compiler/Core/Data/TransformationId.hs +++ b/src/Juvix/Compiler/Core/Data/TransformationId.hs @@ -32,6 +32,7 @@ data TransformationId | LetFolding | LambdaFolding | LetHoisting + | LoopHoisting | Inlining | MandatoryInlining | FoldTypeSynonyms @@ -109,6 +110,7 @@ instance TransformationId' TransformationId where LetFolding -> strLetFolding LambdaFolding -> strLambdaFolding LetHoisting -> strLetHoisting + LoopHoisting -> strLoopHoisting Inlining -> strInlining MandatoryInlining -> strMandatoryInlining FoldTypeSynonyms -> strFoldTypeSynonyms diff --git a/src/Juvix/Compiler/Core/Data/TransformationId/Strings.hs b/src/Juvix/Compiler/Core/Data/TransformationId/Strings.hs index 8b7f08824a..1e2746648b 100644 --- a/src/Juvix/Compiler/Core/Data/TransformationId/Strings.hs +++ b/src/Juvix/Compiler/Core/Data/TransformationId/Strings.hs @@ -5,6 +5,9 @@ import Juvix.Prelude strLetHoisting :: Text strLetHoisting = "let-hoisting" +strLoopHoisting :: Text +strLoopHoisting = "loop-hoisting" + strStoredPipeline :: Text strStoredPipeline = "pipeline-stored" diff --git a/src/Juvix/Compiler/Core/Transformation.hs b/src/Juvix/Compiler/Core/Transformation.hs index 434ef8dba4..ffc31c9fee 100644 --- a/src/Juvix/Compiler/Core/Transformation.hs +++ b/src/Juvix/Compiler/Core/Transformation.hs @@ -43,6 +43,7 @@ import Juvix.Compiler.Core.Transformation.Optimize.FilterUnreachable (filterUnre import Juvix.Compiler.Core.Transformation.Optimize.Inlining import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding import Juvix.Compiler.Core.Transformation.Optimize.LetFolding +import Juvix.Compiler.Core.Transformation.Optimize.LoopHoisting import Juvix.Compiler.Core.Transformation.Optimize.MandatoryInlining import Juvix.Compiler.Core.Transformation.Optimize.Phase.Eval qualified as Phase.Eval import Juvix.Compiler.Core.Transformation.Optimize.Phase.Exec qualified as Phase.Exec @@ -92,6 +93,7 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts LetFolding -> return . letFolding LambdaFolding -> return . lambdaFolding LetHoisting -> return . letHoisting + LoopHoisting -> return . loopHoisting Inlining -> inlining MandatoryInlining -> return . mandatoryInlining FoldTypeSynonyms -> return . foldTypeSynonyms From bed4bc8d5ef1e9a21d1f0bd6905c5e97256f78f9 Mon Sep 17 00:00:00 2001 From: Lukasz Czajka Date: Thu, 28 Nov 2024 13:04:22 +0100 Subject: [PATCH 04/12] pre-lifting optimization phase --- src/Juvix/Compiler/Core/Data/TransformationId.hs | 2 ++ .../Compiler/Core/Data/TransformationId/Strings.hs | 3 +++ src/Juvix/Compiler/Core/Transformation.hs | 2 ++ .../Core/Transformation/Optimize/Phase/PreLifting.hs | 10 ++++++++++ 4 files changed, 17 insertions(+) create mode 100644 src/Juvix/Compiler/Core/Transformation/Optimize/Phase/PreLifting.hs diff --git a/src/Juvix/Compiler/Core/Data/TransformationId.hs b/src/Juvix/Compiler/Core/Data/TransformationId.hs index 49045c9aec..60186f8806 100644 --- a/src/Juvix/Compiler/Core/Data/TransformationId.hs +++ b/src/Juvix/Compiler/Core/Data/TransformationId.hs @@ -48,6 +48,7 @@ data TransformationId | OptPhaseExec | OptPhaseVampIR | OptPhaseMain + | OptPhasePreLifting deriving stock (Data, Bounded, Enum, Show) data PipelineId @@ -126,6 +127,7 @@ instance TransformationId' TransformationId where OptPhaseExec -> strOptPhaseExec OptPhaseVampIR -> strOptPhaseVampIR OptPhaseMain -> strOptPhaseMain + OptPhasePreLifting -> strOptPhasePreLifting instance PipelineId' TransformationId PipelineId where pipelineText :: PipelineId -> Text diff --git a/src/Juvix/Compiler/Core/Data/TransformationId/Strings.hs b/src/Juvix/Compiler/Core/Data/TransformationId/Strings.hs index 1e2746648b..f50ef8548a 100644 --- a/src/Juvix/Compiler/Core/Data/TransformationId/Strings.hs +++ b/src/Juvix/Compiler/Core/Data/TransformationId/Strings.hs @@ -145,3 +145,6 @@ strOptPhaseVampIR = "opt-phase-vampir" strOptPhaseMain :: Text strOptPhaseMain = "opt-phase-main" + +strOptPhasePreLifting :: Text +strOptPhasePreLifting = "opt-phase-pre-lifting" diff --git a/src/Juvix/Compiler/Core/Transformation.hs b/src/Juvix/Compiler/Core/Transformation.hs index ffc31c9fee..0c0c66ad9b 100644 --- a/src/Juvix/Compiler/Core/Transformation.hs +++ b/src/Juvix/Compiler/Core/Transformation.hs @@ -48,6 +48,7 @@ import Juvix.Compiler.Core.Transformation.Optimize.MandatoryInlining import Juvix.Compiler.Core.Transformation.Optimize.Phase.Eval qualified as Phase.Eval import Juvix.Compiler.Core.Transformation.Optimize.Phase.Exec qualified as Phase.Exec import Juvix.Compiler.Core.Transformation.Optimize.Phase.Main qualified as Phase.Main +import Juvix.Compiler.Core.Transformation.Optimize.Phase.PreLifting qualified as Phase.PreLifting import Juvix.Compiler.Core.Transformation.Optimize.Phase.VampIR qualified as Phase.VampIR import Juvix.Compiler.Core.Transformation.Optimize.SimplifyComparisons (simplifyComparisons) import Juvix.Compiler.Core.Transformation.Optimize.SimplifyIfs @@ -109,3 +110,4 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts OptPhaseExec -> Phase.Exec.optimize OptPhaseVampIR -> Phase.VampIR.optimize OptPhaseMain -> Phase.Main.optimize + OptPhasePreLifting -> Phase.PreLifting.optimize diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/PreLifting.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/PreLifting.hs new file mode 100644 index 0000000000..666413dd17 --- /dev/null +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/PreLifting.hs @@ -0,0 +1,10 @@ +module Juvix.Compiler.Core.Transformation.Optimize.Phase.PreLifting where + +import Juvix.Compiler.Core.Options +import Juvix.Compiler.Core.Transformation.Base +import Juvix.Compiler.Core.Transformation.Optimize.LoopHoisting + +optimize :: (Member (Reader CoreOptions) r) => Module -> Sem r Module +optimize = + withOptimizationLevel 1 $ + return . loopHoisting From 97656257336b5d9d5983d804e82fe92a2075bc49 Mon Sep 17 00:00:00 2001 From: Lukasz Czajka Date: Thu, 28 Nov 2024 17:09:48 +0100 Subject: [PATCH 05/12] bugfixes & test --- .../Compiler/Core/Data/TransformationId.hs | 2 +- src/Juvix/Compiler/Core/Extra/Utils.hs | 8 ++-- .../Transformation/Optimize/LoopHoisting.hs | 38 +++++++++++++------ .../Optimize/Phase/PreLifting.hs | 5 ++- .../Transformation/Optimize/SpecializeArgs.hs | 2 +- test/Compilation/Positive.hs | 7 +++- tests/Compilation/positive/out/test088.out | 1 + tests/Compilation/positive/test088.juvix | 14 +++++++ 8 files changed, 57 insertions(+), 20 deletions(-) create mode 100644 tests/Compilation/positive/out/test088.out create mode 100644 tests/Compilation/positive/test088.juvix diff --git a/src/Juvix/Compiler/Core/Data/TransformationId.hs b/src/Juvix/Compiler/Core/Data/TransformationId.hs index 60186f8806..43c15e78cc 100644 --- a/src/Juvix/Compiler/Core/Data/TransformationId.hs +++ b/src/Juvix/Compiler/Core/Data/TransformationId.hs @@ -79,7 +79,7 @@ toVampIRTransformations = toStrippedTransformations :: TransformationId -> [TransformationId] toStrippedTransformations checkId = - combineInfoTablesTransformations ++ [checkId, LambdaLetRecLifting, TopEtaExpand, OptPhaseExec, MoveApps, RemoveTypeArgs, DisambiguateNames] + combineInfoTablesTransformations ++ [checkId, OptPhasePreLifting, LambdaLetRecLifting, TopEtaExpand, OptPhaseExec, MoveApps, RemoveTypeArgs, DisambiguateNames] instance TransformationId' TransformationId where transformationText :: TransformationId -> Text diff --git a/src/Juvix/Compiler/Core/Extra/Utils.hs b/src/Juvix/Compiler/Core/Extra/Utils.hs index 8cf5a1ad5f..c502b2f1ab 100644 --- a/src/Juvix/Compiler/Core/Extra/Utils.hs +++ b/src/Juvix/Compiler/Core/Extra/Utils.hs @@ -591,7 +591,7 @@ checkInfoTable tab = && all (isClosed . (^. identifierType)) (tab ^. infoIdentifiers) && all (isClosed . (^. constructorType)) (tab ^. infoConstructors) --- | Checks if an argument is passed without modification to direct recursive calls. +-- | Checks if `n`th argument (zero-based) is passed without modification to direct recursive calls. isArgRecursiveInvariant :: Module -> Symbol -> Int -> Bool isArgRecursiveInvariant tab sym argNum = run $ execState True $ dmapNRM go body where @@ -607,9 +607,9 @@ isArgRecursiveInvariant tab sym argNum = run $ execState True $ dmapNRM go body NIdt Ident {..} | _identSymbol == sym -> let b = - argNum <= length args - && case args !! (argNum - 1) of - NVar Var {..} | _varIndex == lvl + n - argNum -> True + argNum < length args + && case args !! argNum of + NVar Var {..} | _varIndex == lvl + n - argNum - 1 -> True _ -> False in do modify' (&& b) diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/LoopHoisting.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/LoopHoisting.hs index 9fb4875171..f2f2c5ecc5 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/LoopHoisting.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/LoopHoisting.hs @@ -11,8 +11,7 @@ loopHoisting md = mapT (const (umapL go)) md go :: BinderList Binder -> Node -> Node go bl node = case node of NApp {} -> case h of - NIdt Ident {..} -> - goApp bl _identSymbol h 0 args + NIdt Ident {..} -> goApp bl _identSymbol h 0 args _ -> node where (h, args) = unfoldApps node @@ -25,19 +24,24 @@ loopHoisting md = mapT (const (umapL go)) md (info, arg) : args' -> case arg of NLam {} | isArgRecursiveInvariant md sym argNum && isDirectlyRecursive md sym -> - goApp bl sym (goLamApp bl info h arg) (argNum + 1) args' + goLamApp bl sym info h arg (argNum + 1) args' _ -> goApp bl sym (mkApp info h arg) (argNum + 1) args' - goLamApp :: BinderList Binder -> Info -> Node -> Node -> Node - goLamApp bl info h arg - | null subterms = mkApp info h arg + goLamApp :: BinderList Binder -> Symbol -> Info -> Node -> Node -> Int -> [(Info, Node)] -> Node + goLamApp bl sym info h arg argNum args' + | null subterms = goApp bl sym (mkApp info h arg) argNum args' | otherwise = mkLets' - (map (\node -> (computeNodeType' md bl node, node)) subterms) - (mkApp info h (reLambdasRev lams body')) + (map (\node -> (computeNodeType' md bl node, node)) subterms') + ( adjustLetBoundVars + . shift n + $ (mkApps (mkApp info h (reLambdasRev lams body')) args') + ) where (lams, body) = unfoldLambdasRev arg (subterms, body') = extractMaximalInvariantSubterms (length lams) body + n = length subterms + subterms' = zipWith shift [0 ..] subterms extractMaximalInvariantSubterms :: Int -> Node -> ([Node], Node) extractMaximalInvariantSubterms bindersNum body = @@ -51,11 +55,21 @@ loopHoisting md = mapT (const (umapL go)) md extract n node | not (isImmediate md node || isLambda node) && isFullyApplied md node - && null fvars = do + && null boundVars = do k <- length <$> get @[Node] - modify' (node :) - return $ End (mkVar' (n + bindersNum + k)) + modify' ((shift (-(n + bindersNum)) node) :) + -- This variable is later adjusted to the correct index in `adjustLetBoundVars` + return $ End (mkVar' (-k - 1)) | otherwise = return $ Recur node where - fvars = filter (>= n + bindersNum) $ Info.getFreeVars node + boundVars = filter (< n + bindersNum) $ Info.getFreeVars node + + adjustLetBoundVars :: Node -> Node + adjustLetBoundVars = umapN adjust + where + adjust :: Level -> Node -> Node + adjust n node = case node of + NVar Var {..} + | _varIndex < 0 -> mkVar' (n - _varIndex - 1) + _ -> node diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/PreLifting.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/PreLifting.hs index 666413dd17..4819709f0e 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/PreLifting.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/PreLifting.hs @@ -2,9 +2,12 @@ module Juvix.Compiler.Core.Transformation.Optimize.Phase.PreLifting where import Juvix.Compiler.Core.Options import Juvix.Compiler.Core.Transformation.Base +import Juvix.Compiler.Core.Transformation.Optimize.LetFolding import Juvix.Compiler.Core.Transformation.Optimize.LoopHoisting optimize :: (Member (Reader CoreOptions) r) => Module -> Sem r Module optimize = withOptimizationLevel 1 $ - return . loopHoisting + return + . loopHoisting + . letFolding diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/SpecializeArgs.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/SpecializeArgs.hs index d5b5ccd2a7..db76fd102f 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/SpecializeArgs.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/SpecializeArgs.hs @@ -49,7 +49,7 @@ isMarkedSpecializable md = \case False isArgSpecializable :: Module -> Symbol -> Int -> Bool -isArgSpecializable = isArgRecursiveInvariant +isArgSpecializable md sym argNum = isArgRecursiveInvariant md sym (argNum - 1) convertNode :: forall r. (Member InfoTableBuilder r) => Node -> Sem r Node convertNode = dmapLRM go diff --git a/test/Compilation/Positive.hs b/test/Compilation/Positive.hs index 1fbdc9d630..9d45c506f1 100644 --- a/test/Compilation/Positive.hs +++ b/test/Compilation/Positive.hs @@ -510,5 +510,10 @@ tests = "Test087: Deriving Ord" $(mkRelDir ".") $(mkRelFile "test087.juvix") - $(mkRelFile "out/test087.out") + $(mkRelFile "out/test087.out"), + posTest + "Test088: Loop invariant code motion" + $(mkRelDir ".") + $(mkRelFile "test088.juvix") + $(mkRelFile "out/test088.out") ] diff --git a/tests/Compilation/positive/out/test088.out b/tests/Compilation/positive/out/test088.out new file mode 100644 index 0000000000..27ba77ddaf --- /dev/null +++ b/tests/Compilation/positive/out/test088.out @@ -0,0 +1 @@ +true diff --git a/tests/Compilation/positive/test088.juvix b/tests/Compilation/positive/test088.juvix new file mode 100644 index 0000000000..ed841f22a5 --- /dev/null +++ b/tests/Compilation/positive/test088.juvix @@ -0,0 +1,14 @@ +-- Loop-invariant code motion +module test088; + +import Stdlib.Prelude open; + +g (lst : List Nat) : Nat := + for (acc := 0) (x in lst) {acc + x + x * x + 176}; + +f (lst : List Nat) : Bool := + all (x in lst) { + x + g lst + g [0; 1] > 2000 + }; + +main : Bool := f [1; 2; 3; 4; 5; 6; 7; 8; 9; 10]; From e9b73d44db93f606dcb01851eab78c020994ae99 Mon Sep 17 00:00:00 2001 From: Lukasz Czajka Date: Thu, 28 Nov 2024 18:38:06 +0100 Subject: [PATCH 06/12] fix duplication in argument specialization optimization --- src/Juvix/Compiler/Core/Extra/Utils.hs | 8 ++++++-- .../Transformation/Optimize/SpecializeArgs.hs | 20 +++++++++++++++++-- tests/Compilation/positive/out/test088.out | 2 +- tests/Compilation/positive/test088.juvix | 10 +++++++++- 4 files changed, 34 insertions(+), 6 deletions(-) diff --git a/src/Juvix/Compiler/Core/Extra/Utils.hs b/src/Juvix/Compiler/Core/Extra/Utils.hs index c502b2f1ab..6e343c196a 100644 --- a/src/Juvix/Compiler/Core/Extra/Utils.hs +++ b/src/Juvix/Compiler/Core/Extra/Utils.hs @@ -155,6 +155,9 @@ isImmediate md = \case isImmediate' :: Node -> Bool isImmediate' = isImmediate emptyModule +isImmediateOrLambda :: Module -> Node -> Bool +isImmediateOrLambda md node = isImmediate md node || isLambda node + -- | True if the argument is fully evaluated first-order data isDataValue :: Node -> Bool isDataValue = \case @@ -167,7 +170,7 @@ isFullyApplied md node = case h of NIdt Ident {..} | Just ii <- lookupIdentifierInfo' md _identSymbol -> length args >= ii ^. identifierArgsNum - _ -> True + _ -> False where (h, args) = unfoldApps' node @@ -591,7 +594,8 @@ checkInfoTable tab = && all (isClosed . (^. identifierType)) (tab ^. infoIdentifiers) && all (isClosed . (^. constructorType)) (tab ^. infoConstructors) --- | Checks if `n`th argument (zero-based) is passed without modification to direct recursive calls. +-- | Checks if the `n`th argument (zero-based) is passed without modification to +-- direct recursive calls. isArgRecursiveInvariant :: Module -> Symbol -> Int -> Bool isArgRecursiveInvariant tab sym argNum = run $ execState True $ dmapNRM go body where diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/SpecializeArgs.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/SpecializeArgs.hs index db76fd102f..e61c9014c8 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/SpecializeArgs.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/SpecializeArgs.hs @@ -22,10 +22,24 @@ isSpecializable md node = NCtr Constr {..} -> case lookupConstructorInfo md _constrTag ^. constructorPragmas . pragmasSpecialise of Just (PragmaSpecialise False) -> False - _ -> True + _ -> + -- We need to avoid duplication of non-immediate expressions + all (isSpecializableConstrArg md) _constrArgs NApp {} -> - let (h, _) = unfoldApps' node + let (h, args) = unfoldApps' node in isSpecializable md h + -- We need to avoid duplication of non-immediate expressions + && all (isImmediateOrLambda md) args + _ -> False + +isSpecializableConstrArg :: Module -> Node -> Bool +isSpecializableConstrArg md node = + isImmediateOrLambda md node + || case node of + NCtr Constr {..} -> + case lookupConstructorInfo md _constrTag ^. constructorPragmas . pragmasSpecialise of + Just (PragmaSpecialise False) -> False + _ -> all (isSpecializableConstrArg md) _constrArgs _ -> False -- | Check for `h a1 .. an` where `h` is an identifier explicitly marked for @@ -48,6 +62,8 @@ isMarkedSpecializable md = \case _ -> False +-- | Checks if the `n`th argument (one-based) is passed without modification to +-- direct recursive calls. isArgSpecializable :: Module -> Symbol -> Int -> Bool isArgSpecializable md sym argNum = isArgRecursiveInvariant md sym (argNum - 1) diff --git a/tests/Compilation/positive/out/test088.out b/tests/Compilation/positive/out/test088.out index 27ba77ddaf..60d3b2f4a4 100644 --- a/tests/Compilation/positive/out/test088.out +++ b/tests/Compilation/positive/out/test088.out @@ -1 +1 @@ -true +15 diff --git a/tests/Compilation/positive/test088.juvix b/tests/Compilation/positive/test088.juvix index ed841f22a5..a8d2f5412b 100644 --- a/tests/Compilation/positive/test088.juvix +++ b/tests/Compilation/positive/test088.juvix @@ -11,4 +11,12 @@ f (lst : List Nat) : Bool := x + g lst + g [0; 1] > 2000 }; -main : Bool := f [1; 2; 3; 4; 5; 6; 7; 8; 9; 10]; +main : Nat := + if + | f [1; 2; 3; 4; 5; 6; 7; 8; 9; 10] := + for (acc := 0) (x in [1; 2; 3; 4; 5]) { + if + | f [1; 2; 3; 4; 5; 6; 7; 8] := acc + x + | else := 0 + } + | else := 0; From a38fb16a3bb3f0c6212f6cdc6f863b5fbe310b51 Mon Sep 17 00:00:00 2001 From: Lukasz Czajka Date: Thu, 28 Nov 2024 19:25:26 +0100 Subject: [PATCH 07/12] volatility info --- .../Compiler/Core/Info/VolatilityInfo.hs | 21 ++++++++++++++ .../Transformation/Optimize/LetFolding.hs | 10 ++++--- .../Transformation/Optimize/LoopHoisting.hs | 28 +++++++++++++++---- 3 files changed, 49 insertions(+), 10 deletions(-) create mode 100644 src/Juvix/Compiler/Core/Info/VolatilityInfo.hs diff --git a/src/Juvix/Compiler/Core/Info/VolatilityInfo.hs b/src/Juvix/Compiler/Core/Info/VolatilityInfo.hs new file mode 100644 index 0000000000..b3040f8865 --- /dev/null +++ b/src/Juvix/Compiler/Core/Info/VolatilityInfo.hs @@ -0,0 +1,21 @@ +module Juvix.Compiler.Core.Info.VolatilityInfo where + +import Juvix.Compiler.Core.Extra +import Juvix.Compiler.Core.Info qualified as Info + +newtype VolatilityInfo = VolatilityInfo + { _infoIsVolatile :: Bool + } + +instance IsInfo VolatilityInfo + +kVolatilityInfo :: Key VolatilityInfo +kVolatilityInfo = Proxy + +makeLenses ''VolatilityInfo + +isVolatile :: Info -> Bool +isVolatile i = + case Info.lookup kVolatilityInfo i of + Just VolatilityInfo {..} -> _infoIsVolatile + Nothing -> False diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/LetFolding.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/LetFolding.hs index ab9fa789af..aea09b3d7c 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/LetFolding.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/LetFolding.hs @@ -17,6 +17,7 @@ import Juvix.Compiler.Core.Data.BinderList qualified as BL import Juvix.Compiler.Core.Extra import Juvix.Compiler.Core.Info.DebugOpsInfo as Info import Juvix.Compiler.Core.Info.FreeVarsInfo as Info +import Juvix.Compiler.Core.Info.VolatilityInfo qualified as Info import Juvix.Compiler.Core.Transformation.Base convertNode :: (Module -> BinderList Binder -> Node -> Bool) -> Module -> Node -> Node @@ -25,10 +26,11 @@ convertNode isFoldable md = rmapL go go :: ([BinderChange] -> Node -> Node) -> BinderList Binder -> Node -> Node go recur bl = \case NLet Let {..} - | ( isImmediate md (_letItem ^. letItemValue) - || Info.freeVarOccurrences 0 _letBody <= 1 - || isFoldable md bl (_letItem ^. letItemValue) - ) + | not (Info.isVolatile _letInfo) + && ( isImmediate md (_letItem ^. letItemValue) + || Info.freeVarOccurrences 0 _letBody <= 1 + || isFoldable md bl (_letItem ^. letItemValue) + ) && not (Info.hasDebugOps _letBody) -> go (recur . (mkBCRemove b val' :)) (BL.cons b bl) _letBody where diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/LoopHoisting.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/LoopHoisting.hs index f2f2c5ecc5..0f846a6e14 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/LoopHoisting.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/LoopHoisting.hs @@ -1,7 +1,9 @@ module Juvix.Compiler.Core.Transformation.Optimize.LoopHoisting (loopHoisting) where import Juvix.Compiler.Core.Extra +import Juvix.Compiler.Core.Info qualified as Info import Juvix.Compiler.Core.Info.FreeVarsInfo qualified as Info +import Juvix.Compiler.Core.Info.VolatilityInfo qualified as Info import Juvix.Compiler.Core.Transformation.Base import Juvix.Compiler.Core.Transformation.ComputeTypeInfo (computeNodeType') @@ -31,12 +33,13 @@ loopHoisting md = mapT (const (umapL go)) md goLamApp bl sym info h arg argNum args' | null subterms = goApp bl sym (mkApp info h arg) argNum args' | otherwise = - mkLets' - (map (\node -> (computeNodeType' md bl node, node)) subterms') - ( adjustLetBoundVars - . shift n - $ (mkApps (mkApp info h (reLambdasRev lams body')) args') - ) + setLetsVolatile n $ + mkLets' + (map (\node -> (computeNodeType' md bl node, node)) subterms') + ( adjustLetBoundVars + . shift n + $ (mkApps (mkApp info h (reLambdasRev lams body')) args') + ) where (lams, body) = unfoldLambdasRev arg (subterms, body') = extractMaximalInvariantSubterms (length lams) body @@ -73,3 +76,16 @@ loopHoisting md = mapT (const (umapL go)) md NVar Var {..} | _varIndex < 0 -> mkVar' (n - _varIndex - 1) _ -> node + + setLetsVolatile :: Int -> Node -> Node + setLetsVolatile n + | n == 0 = id + | otherwise = \case + NLet Let {..} -> + NLet + Let + { _letInfo = Info.insert (Info.VolatilityInfo True) _letInfo, + _letBody = setLetsVolatile (n - 1) _letBody, + _letItem + } + node -> node From d11af02632dbe60c9b612dc497ce2636905b88a7 Mon Sep 17 00:00:00 2001 From: Lukasz Czajka Date: Thu, 28 Nov 2024 19:41:23 +0100 Subject: [PATCH 08/12] TODOs --- .../Compiler/Core/Transformation/Optimize/LoopHoisting.hs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/LoopHoisting.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/LoopHoisting.hs index 0f846a6e14..a08766d601 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/LoopHoisting.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/LoopHoisting.hs @@ -13,9 +13,11 @@ loopHoisting md = mapT (const (umapL go)) md go :: BinderList Binder -> Node -> Node go bl node = case node of NApp {} -> case h of + -- TODO: variables NIdt Ident {..} -> goApp bl _identSymbol h 0 args _ -> node where + -- TODO: consider only fully applied (h, args) = unfoldApps node _ -> node @@ -57,7 +59,7 @@ loopHoisting md = mapT (const (umapL go)) md extract :: (Member (State [Node]) r) => Level -> Node -> Sem r Recur extract n node | not (isImmediate md node || isLambda node) - && isFullyApplied md node + && isFullyApplied md node -- TODO: variables && null boundVars = do k <- length <$> get @[Node] modify' ((shift (-(n + bindersNum)) node) :) From 9731444a8439bc9cdac763326549e9d46383a24a Mon Sep 17 00:00:00 2001 From: Lukasz Czajka Date: Fri, 29 Nov 2024 12:24:15 +0100 Subject: [PATCH 09/12] improve optimizations --- juvix-stdlib | 2 +- src/Juvix/Compiler/Core/Extra/Utils.hs | 15 +++++--- .../Transformation/Optimize/LoopHoisting.hs | 34 ++++++++++++------- .../Transformation/Optimize/Phase/Main.hs | 2 +- .../Optimize/Phase/PreLifting.hs | 21 ++++++++++-- 5 files changed, 53 insertions(+), 21 deletions(-) diff --git a/juvix-stdlib b/juvix-stdlib index fde9ac2353..825899d931 160000 --- a/juvix-stdlib +++ b/juvix-stdlib @@ -1 +1 @@ -Subproject commit fde9ac23534fe1c0ba3f69714233dbd1d3934a9c +Subproject commit 825899d9314c6ff05da207bb82689bdb40d7e782 diff --git a/src/Juvix/Compiler/Core/Extra/Utils.hs b/src/Juvix/Compiler/Core/Extra/Utils.hs index 6e343c196a..c377000459 100644 --- a/src/Juvix/Compiler/Core/Extra/Utils.hs +++ b/src/Juvix/Compiler/Core/Extra/Utils.hs @@ -165,12 +165,19 @@ isDataValue = \case NCtr Constr {..} -> all isDataValue _constrArgs _ -> False -isFullyApplied :: Module -> Node -> Bool -isFullyApplied md node = case h of +isFullyApplied :: Module -> BinderList Binder -> Node -> Bool +isFullyApplied md bl node = case h of NIdt Ident {..} | Just ii <- lookupIdentifierInfo' md _identSymbol -> - length args >= ii ^. identifierArgsNum - _ -> False + length args == ii ^. identifierArgsNum + NVar Var {..} -> + case BL.lookupMay _varIndex bl of + Just Binder {..} -> + length args == length (typeArgs _binderType) + Nothing -> + False + _ -> + False where (h, args) = unfoldApps' node diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/LoopHoisting.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/LoopHoisting.hs index a08766d601..9f08f837de 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/LoopHoisting.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/LoopHoisting.hs @@ -1,5 +1,6 @@ module Juvix.Compiler.Core.Transformation.Optimize.LoopHoisting (loopHoisting) where +import Juvix.Compiler.Core.Data.BinderList qualified as BL import Juvix.Compiler.Core.Extra import Juvix.Compiler.Core.Info qualified as Info import Juvix.Compiler.Core.Info.FreeVarsInfo qualified as Info @@ -13,11 +14,12 @@ loopHoisting md = mapT (const (umapL go)) md go :: BinderList Binder -> Node -> Node go bl node = case node of NApp {} -> case h of - -- TODO: variables - NIdt Ident {..} -> goApp bl _identSymbol h 0 args + NIdt Ident {..} + | Just ii <- lookupIdentifierInfo' md _identSymbol, + length args == ii ^. identifierArgsNum -> + goApp bl _identSymbol h 0 args _ -> node where - -- TODO: consider only fully applied (h, args) = unfoldApps node _ -> node @@ -27,10 +29,14 @@ loopHoisting md = mapT (const (umapL go)) md [] -> h (info, arg) : args' -> case arg of NLam {} - | isArgRecursiveInvariant md sym argNum && isDirectlyRecursive md sym -> + | isHoistable sym argNum -> goLamApp bl sym info h arg (argNum + 1) args' _ -> goApp bl sym (mkApp info h arg) (argNum + 1) args' + isHoistable :: Symbol -> Int -> Bool + isHoistable sym argNum = + isArgRecursiveInvariant md sym argNum && isDirectlyRecursive md sym + goLamApp :: BinderList Binder -> Symbol -> Info -> Node -> Node -> Int -> [(Info, Node)] -> Node goLamApp bl sym info h arg argNum args' | null subterms = goApp bl sym (mkApp info h arg) argNum args' @@ -44,31 +50,33 @@ loopHoisting md = mapT (const (umapL go)) md ) where (lams, body) = unfoldLambdasRev arg - (subterms, body') = extractMaximalInvariantSubterms (length lams) body + bl' = BL.prepend (map (^. lambdaLhsBinder) lams) bl + (subterms, body') = extractMaximalInvariantSubterms (length bl) bl' body n = length subterms subterms' = zipWith shift [0 ..] subterms - extractMaximalInvariantSubterms :: Int -> Node -> ([Node], Node) - extractMaximalInvariantSubterms bindersNum body = + extractMaximalInvariantSubterms :: Int -> BinderList Binder -> Node -> ([Node], Node) + extractMaximalInvariantSubterms initialBindersNum bl0 body = first (map (removeInfo Info.kFreeVarsInfo)) . second (removeInfo Info.kFreeVarsInfo) . run . runState [] - $ dmapNRM extract (Info.computeFreeVarsInfo body) + $ dmapLRM' (bl0, extract) (Info.computeFreeVarsInfo body) where - extract :: (Member (State [Node]) r) => Level -> Node -> Sem r Recur - extract n node + extract :: (Member (State [Node]) r) => BinderList Binder -> Node -> Sem r Recur + extract bl node | not (isImmediate md node || isLambda node) - && isFullyApplied md node -- TODO: variables + && isFullyApplied md bl node && null boundVars = do k <- length <$> get @[Node] - modify' ((shift (-(n + bindersNum)) node) :) + modify' ((shift (-n) node) :) -- This variable is later adjusted to the correct index in `adjustLetBoundVars` return $ End (mkVar' (-k - 1)) | otherwise = return $ Recur node where - boundVars = filter (< n + bindersNum) $ Info.getFreeVars node + boundVars = filter (< n) $ Info.getFreeVars node + n = length bl - initialBindersNum adjustLetBoundVars :: Node -> Node adjustLetBoundVars = umapN adjust diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs index 9a4ee89a24..5e5aed33d7 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs @@ -19,7 +19,7 @@ optimize' :: CoreOptions -> Module -> Module optimize' opts@CoreOptions {..} md = filterUnreachable . compose - (6 * _optOptimizationLevel) + (4 * _optOptimizationLevel) ( doConstantFolding . doSimplification 1 . specializeArgs diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/PreLifting.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/PreLifting.hs index 4819709f0e..0e73e1a7db 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/PreLifting.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/PreLifting.hs @@ -1,13 +1,30 @@ module Juvix.Compiler.Core.Transformation.Optimize.Phase.PreLifting where +import Juvix.Compiler.Core.Data.IdentDependencyInfo import Juvix.Compiler.Core.Options import Juvix.Compiler.Core.Transformation.Base +import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding +import Juvix.Compiler.Core.Transformation.Optimize.Inlining +import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding import Juvix.Compiler.Core.Transformation.Optimize.LetFolding import Juvix.Compiler.Core.Transformation.Optimize.LoopHoisting optimize :: (Member (Reader CoreOptions) r) => Module -> Sem r Module -optimize = - withOptimizationLevel 1 $ +optimize md = do + CoreOptions {..} <- ask + withOptimizationLevel' md 1 $ return . loopHoisting . letFolding + . lambdaFolding + . letFolding + . caseFolding + . compose + 2 + ( compose 2 (letFolding' (isInlineableLambda _optInliningDepth)) + . lambdaFolding + . inlining' _optInliningDepth nonRecSyms + ) + . letFolding + where + nonRecSyms = nonRecursiveIdents md From ceed253eb5c3ab9d4b956086d538c4134ec78dba Mon Sep 17 00:00:00 2001 From: Lukasz Czajka Date: Fri, 29 Nov 2024 12:27:51 +0100 Subject: [PATCH 10/12] update stdlib --- juvix-stdlib | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/juvix-stdlib b/juvix-stdlib index 825899d931..0c456725a2 160000 --- a/juvix-stdlib +++ b/juvix-stdlib @@ -1 +1 @@ -Subproject commit 825899d9314c6ff05da207bb82689bdb40d7e782 +Subproject commit 0c456725a23648606f97aebf8f74a9e2a73e90b6 From b33e5c2be31d4df5a330f16a626ff8752920c1b9 Mon Sep 17 00:00:00 2001 From: Lukasz Czajka Date: Fri, 29 Nov 2024 13:08:28 +0100 Subject: [PATCH 11/12] rename test --- test/Compilation/Positive.hs | 6 +++--- tests/Compilation/positive/out/test088.out | 1 - tests/Compilation/positive/test088.juvix | 22 ---------------------- 3 files changed, 3 insertions(+), 26 deletions(-) delete mode 100644 tests/Compilation/positive/out/test088.out delete mode 100644 tests/Compilation/positive/test088.juvix diff --git a/test/Compilation/Positive.hs b/test/Compilation/Positive.hs index 9d45c506f1..55ff7d7941 100644 --- a/test/Compilation/Positive.hs +++ b/test/Compilation/Positive.hs @@ -512,8 +512,8 @@ tests = $(mkRelFile "test087.juvix") $(mkRelFile "out/test087.out"), posTest - "Test088: Loop invariant code motion" + "Test089: Loop invariant code motion" $(mkRelDir ".") - $(mkRelFile "test088.juvix") - $(mkRelFile "out/test088.out") + $(mkRelFile "test089.juvix") + $(mkRelFile "out/test089.out") ] diff --git a/tests/Compilation/positive/out/test088.out b/tests/Compilation/positive/out/test088.out deleted file mode 100644 index 60d3b2f4a4..0000000000 --- a/tests/Compilation/positive/out/test088.out +++ /dev/null @@ -1 +0,0 @@ -15 diff --git a/tests/Compilation/positive/test088.juvix b/tests/Compilation/positive/test088.juvix deleted file mode 100644 index a8d2f5412b..0000000000 --- a/tests/Compilation/positive/test088.juvix +++ /dev/null @@ -1,22 +0,0 @@ --- Loop-invariant code motion -module test088; - -import Stdlib.Prelude open; - -g (lst : List Nat) : Nat := - for (acc := 0) (x in lst) {acc + x + x * x + 176}; - -f (lst : List Nat) : Bool := - all (x in lst) { - x + g lst + g [0; 1] > 2000 - }; - -main : Nat := - if - | f [1; 2; 3; 4; 5; 6; 7; 8; 9; 10] := - for (acc := 0) (x in [1; 2; 3; 4; 5]) { - if - | f [1; 2; 3; 4; 5; 6; 7; 8] := acc + x - | else := 0 - } - | else := 0; From 37258d13a0a1ecd28c1c9c5374044314d0a11a21 Mon Sep 17 00:00:00 2001 From: Lukasz Czajka Date: Fri, 29 Nov 2024 13:09:40 +0100 Subject: [PATCH 12/12] add test --- tests/Compilation/positive/out/test089.out | 1 + tests/Compilation/positive/test089.juvix | 22 ++++++++++++++++++++++ 2 files changed, 23 insertions(+) create mode 100644 tests/Compilation/positive/out/test089.out create mode 100644 tests/Compilation/positive/test089.juvix diff --git a/tests/Compilation/positive/out/test089.out b/tests/Compilation/positive/out/test089.out new file mode 100644 index 0000000000..60d3b2f4a4 --- /dev/null +++ b/tests/Compilation/positive/out/test089.out @@ -0,0 +1 @@ +15 diff --git a/tests/Compilation/positive/test089.juvix b/tests/Compilation/positive/test089.juvix new file mode 100644 index 0000000000..44e1bc81a6 --- /dev/null +++ b/tests/Compilation/positive/test089.juvix @@ -0,0 +1,22 @@ +-- Loop-invariant code motion +module test089; + +import Stdlib.Prelude open; + +g (lst : List Nat) : Nat := + for (acc := 0) (x in lst) {acc + x + x * x + 176}; + +f (lst : List Nat) : Bool := + all (x in lst) { + x + g lst + g [0; 1] > 2000 + }; + +main : Nat := + if + | f [1; 2; 3; 4; 5; 6; 7; 8; 9; 10] := + for (acc := 0) (x in [1; 2; 3; 4; 5]) { + if + | f [1; 2; 3; 4; 5; 6; 7; 8] := acc + x + | else := 0 + } + | else := 0;