Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inline non-recursive functions with only one call site #3204

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions src/Juvix/Compiler/Core/Extra/Utils.hs
Original file line number Diff line number Diff line change
Expand Up @@ -576,3 +576,32 @@ checkInfoTable tab =
all isClosed (tab ^. identContext)
&& all (isClosed . (^. identifierType)) (tab ^. infoIdentifiers)
&& all (isClosed . (^. constructorType)) (tab ^. infoConstructors)

-- Returns a map from symbols to their number of occurrences in the given node.
getSymbolsMap :: Module -> Node -> HashMap Symbol Int
getSymbolsMap md = gather go mempty
where
go :: HashMap Symbol Int -> Node -> HashMap Symbol Int
go acc = \case
NTyp TypeConstr {..} -> mapInc _typeConstrSymbol acc
NIdt Ident {..} -> mapInc _identSymbol acc
NCase Case {..} -> mapInc _caseInductive acc
NCtr Constr {..}
| Just ci <- lookupConstructorInfo' md _constrTag ->
mapInc (ci ^. constructorInductive) acc
_ -> acc

mapInc :: Symbol -> HashMap Symbol Int -> HashMap Symbol Int
mapInc k = HashMap.insertWith (+) k 1

getTableSymbolsMap :: InfoTable -> HashMap Symbol Int
getTableSymbolsMap tab =
foldr
(HashMap.unionWith (+))
mempty
(map (getSymbolsMap md) (HashMap.elems $ tab ^. identContext))
where
md = emptyModule {_moduleInfoTable = tab}

getModuleSymbolsMap :: Module -> HashMap Symbol Int
getModuleSymbolsMap = getTableSymbolsMap . computeCombinedInfoTable
21 changes: 13 additions & 8 deletions src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module Juvix.Compiler.Core.Transformation.Optimize.Inlining where

import Data.HashMap.Strict qualified as HashMap
import Data.HashSet qualified as HashSet
import Juvix.Compiler.Core.Data.BinderList qualified as BL
import Juvix.Compiler.Core.Data.IdentDependencyInfo
Expand All @@ -16,8 +17,8 @@ isInlineableLambda inlineDepth md bl node = case node of
_ ->
False

convertNode :: Int -> HashSet Symbol -> Module -> Node -> Node
convertNode inlineDepth nonRecSyms md = dmapL go
convertNode :: Int -> HashSet Symbol -> HashMap Symbol Int -> Module -> Node -> Node
convertNode inlineDepth nonRecSyms symOcc md = dmapL go
where
go :: BinderList Binder -> Node -> Node
go bl node = case node of
Expand All @@ -38,8 +39,10 @@ convertNode inlineDepth nonRecSyms md = dmapL go
node
_
| HashSet.member _identSymbol nonRecSyms
&& isInlineableLambda inlineDepth md bl def
&& length args >= argsNum ->
&& length args >= argsNum
&& ( HashMap.lookup _identSymbol symOcc == Just 1
|| isInlineableLambda inlineDepth md bl def
) ->
mkApps def args
_ ->
node
Expand All @@ -58,7 +61,9 @@ convertNode inlineDepth nonRecSyms md = dmapL go
Just InlineNever -> node
_
| HashSet.member _identSymbol nonRecSyms
&& isImmediate md def ->
&& ( HashMap.lookup _identSymbol symOcc == Just 1
|| isImmediate md def
) ->
def
| otherwise ->
node
Expand Down Expand Up @@ -98,10 +103,10 @@ convertNode inlineDepth nonRecSyms md = dmapL go
where
(lamsNum, body) = unfoldLambdas' node

inlining' :: Int -> HashSet Symbol -> Module -> Module
inlining' inliningDepth nonRecSyms md = mapT (const (convertNode inliningDepth nonRecSyms md)) md
inlining' :: Int -> HashSet Symbol -> HashMap Symbol Int -> Module -> Module
inlining' inliningDepth nonRecSyms symOcc md = mapT (const (convertNode inliningDepth nonRecSyms symOcc md)) md

inlining :: (Member (Reader CoreOptions) r) => Module -> Sem r Module
inlining md = do
d <- asks (^. optInliningDepth)
return $ inlining' d (nonRecursiveIdents md) md
return $ inlining' d (nonRecursiveIdents md) (getModuleSymbolsMap md) md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module Juvix.Compiler.Core.Transformation.Optimize.Phase.Main where

import Juvix.Compiler.Core.Data.IdentDependencyInfo
import Juvix.Compiler.Core.Extra.Utils (getTableSymbolsMap)
import Juvix.Compiler.Core.Options
import Juvix.Compiler.Core.Transformation.Base
import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding
Expand Down Expand Up @@ -39,6 +40,9 @@ optimize' opts@CoreOptions {..} md =
nonRecsReachable :: HashSet Symbol
nonRecsReachable = nonRecursiveReachableIdents' tab

symOcc :: HashMap Symbol Int
symOcc = getTableSymbolsMap tab

doConstantFolding :: Module -> Module
doConstantFolding md' = constantFolding' opts nonRecs' tab' md'
where
Expand All @@ -48,7 +52,7 @@ optimize' opts@CoreOptions {..} md =
| otherwise = nonRecsReachable

doInlining :: Module -> Module
doInlining md' = inlining' _optInliningDepth nonRecs' md'
doInlining md' = inlining' _optInliningDepth nonRecs' symOcc md'
where
nonRecs' =
if
Expand Down
Loading