diff --git a/juvix-stdlib b/juvix-stdlib index fde9ac2353..0c456725a2 160000 --- a/juvix-stdlib +++ b/juvix-stdlib @@ -1 +1 @@ -Subproject commit fde9ac23534fe1c0ba3f69714233dbd1d3934a9c +Subproject commit 0c456725a23648606f97aebf8f74a9e2a73e90b6 diff --git a/src/Juvix/Compiler/Core/Data/TransformationId.hs b/src/Juvix/Compiler/Core/Data/TransformationId.hs index a951585dae..43c15e78cc 100644 --- a/src/Juvix/Compiler/Core/Data/TransformationId.hs +++ b/src/Juvix/Compiler/Core/Data/TransformationId.hs @@ -32,6 +32,7 @@ data TransformationId | LetFolding | LambdaFolding | LetHoisting + | LoopHoisting | Inlining | MandatoryInlining | FoldTypeSynonyms @@ -47,6 +48,7 @@ data TransformationId | OptPhaseExec | OptPhaseVampIR | OptPhaseMain + | OptPhasePreLifting deriving stock (Data, Bounded, Enum, Show) data PipelineId @@ -77,7 +79,7 @@ toVampIRTransformations = toStrippedTransformations :: TransformationId -> [TransformationId] toStrippedTransformations checkId = - combineInfoTablesTransformations ++ [checkId, LambdaLetRecLifting, TopEtaExpand, OptPhaseExec, MoveApps, RemoveTypeArgs, DisambiguateNames] + combineInfoTablesTransformations ++ [checkId, OptPhasePreLifting, LambdaLetRecLifting, TopEtaExpand, OptPhaseExec, MoveApps, RemoveTypeArgs, DisambiguateNames] instance TransformationId' TransformationId where transformationText :: TransformationId -> Text @@ -109,6 +111,7 @@ instance TransformationId' TransformationId where LetFolding -> strLetFolding LambdaFolding -> strLambdaFolding LetHoisting -> strLetHoisting + LoopHoisting -> strLoopHoisting Inlining -> strInlining MandatoryInlining -> strMandatoryInlining FoldTypeSynonyms -> strFoldTypeSynonyms @@ -124,6 +127,7 @@ instance TransformationId' TransformationId where OptPhaseExec -> strOptPhaseExec OptPhaseVampIR -> strOptPhaseVampIR OptPhaseMain -> strOptPhaseMain + OptPhasePreLifting -> strOptPhasePreLifting instance PipelineId' TransformationId PipelineId where pipelineText :: PipelineId -> Text diff --git a/src/Juvix/Compiler/Core/Data/TransformationId/Strings.hs b/src/Juvix/Compiler/Core/Data/TransformationId/Strings.hs index 8b7f08824a..f50ef8548a 100644 --- a/src/Juvix/Compiler/Core/Data/TransformationId/Strings.hs +++ b/src/Juvix/Compiler/Core/Data/TransformationId/Strings.hs @@ -5,6 +5,9 @@ import Juvix.Prelude strLetHoisting :: Text strLetHoisting = "let-hoisting" +strLoopHoisting :: Text +strLoopHoisting = "loop-hoisting" + strStoredPipeline :: Text strStoredPipeline = "pipeline-stored" @@ -142,3 +145,6 @@ strOptPhaseVampIR = "opt-phase-vampir" strOptPhaseMain :: Text strOptPhaseMain = "opt-phase-main" + +strOptPhasePreLifting :: Text +strOptPhasePreLifting = "opt-phase-pre-lifting" diff --git a/src/Juvix/Compiler/Core/Extra/Base.hs b/src/Juvix/Compiler/Core/Extra/Base.hs index f5a1b6b775..5d2a798b81 100644 --- a/src/Juvix/Compiler/Core/Extra/Base.hs +++ b/src/Juvix/Compiler/Core/Extra/Base.hs @@ -59,12 +59,18 @@ mkLambda i bi b = NLam (Lambda i bi b) mkLambda' :: Type -> Node -> Node mkLambda' ty = mkLambda Info.empty (mkBinder' ty) +mkLambda'' :: Binder -> Node -> Node +mkLambda'' = mkLambda Info.empty + mkLambdas :: [Info] -> [Binder] -> Node -> Node mkLambdas is bs n = foldl' (flip (uncurry mkLambda)) n (reverse (zipExact is bs)) mkLambdas' :: [Type] -> Node -> Node mkLambdas' tys n = foldl' (flip mkLambda') n (reverse tys) +mkLambdas'' :: [Binder] -> Node -> Node +mkLambdas'' bs n = foldl' (flip mkLambda'') n (reverse bs) + mkLetItem :: Text -> Type -> Node -> LetItem mkLetItem name ty = LetItem (mkBinder name ty) diff --git a/src/Juvix/Compiler/Core/Extra/Utils.hs b/src/Juvix/Compiler/Core/Extra/Utils.hs index 77d5d836d0..c377000459 100644 --- a/src/Juvix/Compiler/Core/Extra/Utils.hs +++ b/src/Juvix/Compiler/Core/Extra/Utils.hs @@ -155,6 +155,9 @@ isImmediate md = \case isImmediate' :: Node -> Bool isImmediate' = isImmediate emptyModule +isImmediateOrLambda :: Module -> Node -> Bool +isImmediateOrLambda md node = isImmediate md node || isLambda node + -- | True if the argument is fully evaluated first-order data isDataValue :: Node -> Bool isDataValue = \case @@ -162,11 +165,32 @@ isDataValue = \case NCtr Constr {..} -> all isDataValue _constrArgs _ -> False +isFullyApplied :: Module -> BinderList Binder -> Node -> Bool +isFullyApplied md bl node = case h of + NIdt Ident {..} + | Just ii <- lookupIdentifierInfo' md _identSymbol -> + length args == ii ^. identifierArgsNum + NVar Var {..} -> + case BL.lookupMay _varIndex bl of + Just Binder {..} -> + length args == length (typeArgs _binderType) + Nothing -> + False + _ -> + False + where + (h, args) = unfoldApps' node + isFailNode :: Node -> Bool isFailNode = \case NBlt (BuiltinApp {..}) | _builtinAppOp == OpFail -> True _ -> False +isLambda :: Node -> Bool +isLambda = \case + NLam {} -> True + _ -> False + isTrueConstr :: Node -> Bool isTrueConstr = \case NCtr Constr {..} | _constrTag == BuiltinTag TagTrue -> True @@ -576,3 +600,43 @@ checkInfoTable tab = all isClosed (tab ^. identContext) && all (isClosed . (^. identifierType)) (tab ^. infoIdentifiers) && all (isClosed . (^. constructorType)) (tab ^. infoConstructors) + +-- | Checks if the `n`th argument (zero-based) is passed without modification to +-- direct recursive calls. +isArgRecursiveInvariant :: Module -> Symbol -> Int -> Bool +isArgRecursiveInvariant tab sym argNum = run $ execState True $ dmapNRM go body + where + nodeSym = lookupIdentifierNode tab sym + (lams, body) = unfoldLambdas nodeSym + n = length lams + + go :: (Member (State Bool) r) => Level -> Node -> Sem r Recur + go lvl node = case node of + NApp {} -> + let (h, args) = unfoldApps' node + in case h of + NIdt Ident {..} + | _identSymbol == sym -> + let b = + argNum < length args + && case args !! argNum of + NVar Var {..} | _varIndex == lvl + n - argNum - 1 -> True + _ -> False + in do + modify' (&& b) + mapM_ (dmapNRM' (lvl, go)) args + return $ End node + _ -> return $ Recur node + NIdt Ident {..} + | _identSymbol == sym -> do + put False + return $ End node + _ -> return $ Recur node + +isDirectlyRecursive :: Module -> Symbol -> Bool +isDirectlyRecursive md sym = ufold (\x xs -> or (x : xs)) go (lookupIdentifierNode md sym) + where + go :: Node -> Bool + go = \case + NIdt Ident {..} -> _identSymbol == sym + _ -> False diff --git a/src/Juvix/Compiler/Core/Info/FreeVarsInfo.hs b/src/Juvix/Compiler/Core/Info/FreeVarsInfo.hs index 312c6cfe61..07c3e13fcb 100644 --- a/src/Juvix/Compiler/Core/Info/FreeVarsInfo.hs +++ b/src/Juvix/Compiler/Core/Info/FreeVarsInfo.hs @@ -38,7 +38,7 @@ computeFreeVarsInfo' lambdaMultiplier = umap go fvi = FreeVarsInfo . fmap (* lambdaMultiplier) - $ getFreeVars 1 _lambdaBody + $ getFreeVars' 1 _lambdaBody _ -> modifyInfo (Info.insert fvi) node where @@ -47,13 +47,13 @@ computeFreeVarsInfo' lambdaMultiplier = umap go foldr ( \NodeChild {..} acc -> Map.unionWith (+) acc $ - getFreeVars _childBindersNum _childNode + getFreeVars' _childBindersNum _childNode ) mempty (children node) - getFreeVars :: Int -> Node -> Map Index Int - getFreeVars bindersNum node = + getFreeVars' :: Int -> Node -> Map Index Int + getFreeVars' bindersNum node = Map.mapKeysMonotonic (\idx -> idx - bindersNum) . Map.filterWithKey (\idx _ -> idx >= bindersNum) $ getFreeVarsInfo node ^. infoFreeVars @@ -61,6 +61,9 @@ computeFreeVarsInfo' lambdaMultiplier = umap go getFreeVarsInfo :: Node -> FreeVarsInfo getFreeVarsInfo = fromJust . Info.lookup kFreeVarsInfo . getInfo +getFreeVars :: Node -> [Index] +getFreeVars = Map.keys . Map.filter (> 0) . (^. infoFreeVars) . getFreeVarsInfo + freeVarOccurrences :: Index -> Node -> Int freeVarOccurrences idx n = fromMaybe 0 (Map.lookup idx (getFreeVarsInfo n ^. infoFreeVars)) diff --git a/src/Juvix/Compiler/Core/Info/VolatilityInfo.hs b/src/Juvix/Compiler/Core/Info/VolatilityInfo.hs new file mode 100644 index 0000000000..b3040f8865 --- /dev/null +++ b/src/Juvix/Compiler/Core/Info/VolatilityInfo.hs @@ -0,0 +1,21 @@ +module Juvix.Compiler.Core.Info.VolatilityInfo where + +import Juvix.Compiler.Core.Extra +import Juvix.Compiler.Core.Info qualified as Info + +newtype VolatilityInfo = VolatilityInfo + { _infoIsVolatile :: Bool + } + +instance IsInfo VolatilityInfo + +kVolatilityInfo :: Key VolatilityInfo +kVolatilityInfo = Proxy + +makeLenses ''VolatilityInfo + +isVolatile :: Info -> Bool +isVolatile i = + case Info.lookup kVolatilityInfo i of + Just VolatilityInfo {..} -> _infoIsVolatile + Nothing -> False diff --git a/src/Juvix/Compiler/Core/Transformation.hs b/src/Juvix/Compiler/Core/Transformation.hs index 434ef8dba4..0c0c66ad9b 100644 --- a/src/Juvix/Compiler/Core/Transformation.hs +++ b/src/Juvix/Compiler/Core/Transformation.hs @@ -43,10 +43,12 @@ import Juvix.Compiler.Core.Transformation.Optimize.FilterUnreachable (filterUnre import Juvix.Compiler.Core.Transformation.Optimize.Inlining import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding import Juvix.Compiler.Core.Transformation.Optimize.LetFolding +import Juvix.Compiler.Core.Transformation.Optimize.LoopHoisting import Juvix.Compiler.Core.Transformation.Optimize.MandatoryInlining import Juvix.Compiler.Core.Transformation.Optimize.Phase.Eval qualified as Phase.Eval import Juvix.Compiler.Core.Transformation.Optimize.Phase.Exec qualified as Phase.Exec import Juvix.Compiler.Core.Transformation.Optimize.Phase.Main qualified as Phase.Main +import Juvix.Compiler.Core.Transformation.Optimize.Phase.PreLifting qualified as Phase.PreLifting import Juvix.Compiler.Core.Transformation.Optimize.Phase.VampIR qualified as Phase.VampIR import Juvix.Compiler.Core.Transformation.Optimize.SimplifyComparisons (simplifyComparisons) import Juvix.Compiler.Core.Transformation.Optimize.SimplifyIfs @@ -92,6 +94,7 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts LetFolding -> return . letFolding LambdaFolding -> return . lambdaFolding LetHoisting -> return . letHoisting + LoopHoisting -> return . loopHoisting Inlining -> inlining MandatoryInlining -> return . mandatoryInlining FoldTypeSynonyms -> return . foldTypeSynonyms @@ -107,3 +110,4 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts OptPhaseExec -> Phase.Exec.optimize OptPhaseVampIR -> Phase.VampIR.optimize OptPhaseMain -> Phase.Main.optimize + OptPhasePreLifting -> Phase.PreLifting.optimize diff --git a/src/Juvix/Compiler/Core/Transformation/ComputeTypeInfo.hs b/src/Juvix/Compiler/Core/Transformation/ComputeTypeInfo.hs index 4e41e5014f..e5972ccb50 100644 --- a/src/Juvix/Compiler/Core/Transformation/ComputeTypeInfo.hs +++ b/src/Juvix/Compiler/Core/Transformation/ComputeTypeInfo.hs @@ -5,6 +5,14 @@ import Juvix.Compiler.Core.Extra import Juvix.Compiler.Core.Info.TypeInfo qualified as Info import Juvix.Compiler.Core.Transformation.Base +computeNodeType' :: Module -> BinderList Binder -> Node -> Type +computeNodeType' md bl node = rePis argtys' ty' + where + ty = computeNodeType md (mkLambdas'' (reverse (toList bl)) node) + (argtys, ty') = unfoldPi ty + argtys' = drop (length bl) argtys + +-- | Computes the type of a closed well-typed node. computeNodeType :: Module -> Node -> Type computeNodeType md = Info.getNodeType . computeNodeTypeInfo md diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/LetFolding.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/LetFolding.hs index ab9fa789af..aea09b3d7c 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/LetFolding.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/LetFolding.hs @@ -17,6 +17,7 @@ import Juvix.Compiler.Core.Data.BinderList qualified as BL import Juvix.Compiler.Core.Extra import Juvix.Compiler.Core.Info.DebugOpsInfo as Info import Juvix.Compiler.Core.Info.FreeVarsInfo as Info +import Juvix.Compiler.Core.Info.VolatilityInfo qualified as Info import Juvix.Compiler.Core.Transformation.Base convertNode :: (Module -> BinderList Binder -> Node -> Bool) -> Module -> Node -> Node @@ -25,10 +26,11 @@ convertNode isFoldable md = rmapL go go :: ([BinderChange] -> Node -> Node) -> BinderList Binder -> Node -> Node go recur bl = \case NLet Let {..} - | ( isImmediate md (_letItem ^. letItemValue) - || Info.freeVarOccurrences 0 _letBody <= 1 - || isFoldable md bl (_letItem ^. letItemValue) - ) + | not (Info.isVolatile _letInfo) + && ( isImmediate md (_letItem ^. letItemValue) + || Info.freeVarOccurrences 0 _letBody <= 1 + || isFoldable md bl (_letItem ^. letItemValue) + ) && not (Info.hasDebugOps _letBody) -> go (recur . (mkBCRemove b val' :)) (BL.cons b bl) _letBody where diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/LoopHoisting.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/LoopHoisting.hs new file mode 100644 index 0000000000..9f08f837de --- /dev/null +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/LoopHoisting.hs @@ -0,0 +1,101 @@ +module Juvix.Compiler.Core.Transformation.Optimize.LoopHoisting (loopHoisting) where + +import Juvix.Compiler.Core.Data.BinderList qualified as BL +import Juvix.Compiler.Core.Extra +import Juvix.Compiler.Core.Info qualified as Info +import Juvix.Compiler.Core.Info.FreeVarsInfo qualified as Info +import Juvix.Compiler.Core.Info.VolatilityInfo qualified as Info +import Juvix.Compiler.Core.Transformation.Base +import Juvix.Compiler.Core.Transformation.ComputeTypeInfo (computeNodeType') + +loopHoisting :: Module -> Module +loopHoisting md = mapT (const (umapL go)) md + where + go :: BinderList Binder -> Node -> Node + go bl node = case node of + NApp {} -> case h of + NIdt Ident {..} + | Just ii <- lookupIdentifierInfo' md _identSymbol, + length args == ii ^. identifierArgsNum -> + goApp bl _identSymbol h 0 args + _ -> node + where + (h, args) = unfoldApps node + _ -> + node + + goApp :: BinderList Binder -> Symbol -> Node -> Int -> [(Info, Node)] -> Node + goApp bl sym h argNum args = case args of + [] -> h + (info, arg) : args' -> case arg of + NLam {} + | isHoistable sym argNum -> + goLamApp bl sym info h arg (argNum + 1) args' + _ -> goApp bl sym (mkApp info h arg) (argNum + 1) args' + + isHoistable :: Symbol -> Int -> Bool + isHoistable sym argNum = + isArgRecursiveInvariant md sym argNum && isDirectlyRecursive md sym + + goLamApp :: BinderList Binder -> Symbol -> Info -> Node -> Node -> Int -> [(Info, Node)] -> Node + goLamApp bl sym info h arg argNum args' + | null subterms = goApp bl sym (mkApp info h arg) argNum args' + | otherwise = + setLetsVolatile n $ + mkLets' + (map (\node -> (computeNodeType' md bl node, node)) subterms') + ( adjustLetBoundVars + . shift n + $ (mkApps (mkApp info h (reLambdasRev lams body')) args') + ) + where + (lams, body) = unfoldLambdasRev arg + bl' = BL.prepend (map (^. lambdaLhsBinder) lams) bl + (subterms, body') = extractMaximalInvariantSubterms (length bl) bl' body + n = length subterms + subterms' = zipWith shift [0 ..] subterms + + extractMaximalInvariantSubterms :: Int -> BinderList Binder -> Node -> ([Node], Node) + extractMaximalInvariantSubterms initialBindersNum bl0 body = + first (map (removeInfo Info.kFreeVarsInfo)) + . second (removeInfo Info.kFreeVarsInfo) + . run + . runState [] + $ dmapLRM' (bl0, extract) (Info.computeFreeVarsInfo body) + where + extract :: (Member (State [Node]) r) => BinderList Binder -> Node -> Sem r Recur + extract bl node + | not (isImmediate md node || isLambda node) + && isFullyApplied md bl node + && null boundVars = do + k <- length <$> get @[Node] + modify' ((shift (-n) node) :) + -- This variable is later adjusted to the correct index in `adjustLetBoundVars` + return $ End (mkVar' (-k - 1)) + | otherwise = + return $ Recur node + where + boundVars = filter (< n) $ Info.getFreeVars node + n = length bl - initialBindersNum + + adjustLetBoundVars :: Node -> Node + adjustLetBoundVars = umapN adjust + where + adjust :: Level -> Node -> Node + adjust n node = case node of + NVar Var {..} + | _varIndex < 0 -> mkVar' (n - _varIndex - 1) + _ -> node + + setLetsVolatile :: Int -> Node -> Node + setLetsVolatile n + | n == 0 = id + | otherwise = \case + NLet Let {..} -> + NLet + Let + { _letInfo = Info.insert (Info.VolatilityInfo True) _letInfo, + _letBody = setLetsVolatile (n - 1) _letBody, + _letItem + } + node -> node diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs index 9a4ee89a24..5e5aed33d7 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs @@ -19,7 +19,7 @@ optimize' :: CoreOptions -> Module -> Module optimize' opts@CoreOptions {..} md = filterUnreachable . compose - (6 * _optOptimizationLevel) + (4 * _optOptimizationLevel) ( doConstantFolding . doSimplification 1 . specializeArgs diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/PreLifting.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/PreLifting.hs new file mode 100644 index 0000000000..0e73e1a7db --- /dev/null +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/PreLifting.hs @@ -0,0 +1,30 @@ +module Juvix.Compiler.Core.Transformation.Optimize.Phase.PreLifting where + +import Juvix.Compiler.Core.Data.IdentDependencyInfo +import Juvix.Compiler.Core.Options +import Juvix.Compiler.Core.Transformation.Base +import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding +import Juvix.Compiler.Core.Transformation.Optimize.Inlining +import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding +import Juvix.Compiler.Core.Transformation.Optimize.LetFolding +import Juvix.Compiler.Core.Transformation.Optimize.LoopHoisting + +optimize :: (Member (Reader CoreOptions) r) => Module -> Sem r Module +optimize md = do + CoreOptions {..} <- ask + withOptimizationLevel' md 1 $ + return + . loopHoisting + . letFolding + . lambdaFolding + . letFolding + . caseFolding + . compose + 2 + ( compose 2 (letFolding' (isInlineableLambda _optInliningDepth)) + . lambdaFolding + . inlining' _optInliningDepth nonRecSyms + ) + . letFolding + where + nonRecSyms = nonRecursiveIdents md diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/SpecializeArgs.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/SpecializeArgs.hs index 8f192cd754..e61c9014c8 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/SpecializeArgs.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/SpecializeArgs.hs @@ -22,10 +22,24 @@ isSpecializable md node = NCtr Constr {..} -> case lookupConstructorInfo md _constrTag ^. constructorPragmas . pragmasSpecialise of Just (PragmaSpecialise False) -> False - _ -> True + _ -> + -- We need to avoid duplication of non-immediate expressions + all (isSpecializableConstrArg md) _constrArgs NApp {} -> - let (h, _) = unfoldApps' node + let (h, args) = unfoldApps' node in isSpecializable md h + -- We need to avoid duplication of non-immediate expressions + && all (isImmediateOrLambda md) args + _ -> False + +isSpecializableConstrArg :: Module -> Node -> Bool +isSpecializableConstrArg md node = + isImmediateOrLambda md node + || case node of + NCtr Constr {..} -> + case lookupConstructorInfo md _constrTag ^. constructorPragmas . pragmasSpecialise of + Just (PragmaSpecialise False) -> False + _ -> all (isSpecializableConstrArg md) _constrArgs _ -> False -- | Check for `h a1 .. an` where `h` is an identifier explicitly marked for @@ -48,36 +62,10 @@ isMarkedSpecializable md = \case _ -> False --- | Checks if an argument is passed without modification to recursive calls. +-- | Checks if the `n`th argument (one-based) is passed without modification to +-- direct recursive calls. isArgSpecializable :: Module -> Symbol -> Int -> Bool -isArgSpecializable tab sym argNum = run $ execState True $ dmapNRM go body - where - nodeSym = lookupIdentifierNode tab sym - (lams, body) = unfoldLambdas nodeSym - n = length lams - - go :: (Member (State Bool) r) => Level -> Node -> Sem r Recur - go lvl node = case node of - NApp {} -> - let (h, args) = unfoldApps' node - in case h of - NIdt Ident {..} - | _identSymbol == sym -> - let b = - argNum <= length args - && case args !! (argNum - 1) of - NVar Var {..} | _varIndex == lvl + n - argNum -> True - _ -> False - in do - modify' (&& b) - mapM_ (dmapNRM' (lvl, go)) args - return $ End node - _ -> return $ Recur node - NIdt Ident {..} - | _identSymbol == sym -> do - put False - return $ End node - _ -> return $ Recur node +isArgSpecializable md sym argNum = isArgRecursiveInvariant md sym (argNum - 1) convertNode :: forall r. (Member InfoTableBuilder r) => Node -> Sem r Node convertNode = dmapLRM go diff --git a/test/Compilation/Positive.hs b/test/Compilation/Positive.hs index cdfd8ad9b6..c7c5a26a8e 100644 --- a/test/Compilation/Positive.hs +++ b/test/Compilation/Positive.hs @@ -515,5 +515,10 @@ tests = "Test088: Record update pun" $(mkRelDir ".") $(mkRelFile "test088.juvix") - $(mkRelFile "out/test088.out") + $(mkRelFile "out/test088.out"), + posTest + "Test089: Loop invariant code motion" + $(mkRelDir ".") + $(mkRelFile "test089.juvix") + $(mkRelFile "out/test089.out") ] diff --git a/tests/Compilation/positive/out/test089.out b/tests/Compilation/positive/out/test089.out new file mode 100644 index 0000000000..60d3b2f4a4 --- /dev/null +++ b/tests/Compilation/positive/out/test089.out @@ -0,0 +1 @@ +15 diff --git a/tests/Compilation/positive/test089.juvix b/tests/Compilation/positive/test089.juvix new file mode 100644 index 0000000000..44e1bc81a6 --- /dev/null +++ b/tests/Compilation/positive/test089.juvix @@ -0,0 +1,22 @@ +-- Loop-invariant code motion +module test089; + +import Stdlib.Prelude open; + +g (lst : List Nat) : Nat := + for (acc := 0) (x in lst) {acc + x + x * x + 176}; + +f (lst : List Nat) : Bool := + all (x in lst) { + x + g lst + g [0; 1] > 2000 + }; + +main : Nat := + if + | f [1; 2; 3; 4; 5; 6; 7; 8; 9; 10] := + for (acc := 0) (x in [1; 2; 3; 4; 5]) { + if + | f [1; 2; 3; 4; 5; 6; 7; 8] := acc + x + | else := 0 + } + | else := 0;