Skip to content

Commit

Permalink
Fix apply expression (#52)
Browse files Browse the repository at this point in the history
* Well

* OK, fix one at a time

* Surprisingly ok

* One thing left to fix

* Nice

* Format

* Remove dead code
  • Loading branch information
danieljharvey authored Dec 10, 2024
1 parent 4d2f3c6 commit 03f34ee
Show file tree
Hide file tree
Showing 17 changed files with 178 additions and 108 deletions.
15 changes: 9 additions & 6 deletions wasm-calc12/src/Calc/Ability/Check.hs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import Calc.ExprUtils
import Calc.Types.Ability
import Calc.Types.Expr
import Calc.Types.Function
import Calc.Types.Identifier
import Calc.Types.Import
import Calc.Types.Module
import Calc.Types.ModuleAnnotations
Expand Down Expand Up @@ -159,13 +160,15 @@ abilityExpr (EBox ann a) = do
abilityExpr (EConstructor ann constructor as) = do
tell (S.singleton $ AllocateMemory ann)
EConstructor ann constructor <$> traverse abilityExpr as
abilityExpr (EApply ann fn args) = do
isImport <- asks (S.member fn . aeImportNames)
abilityExpr (EApply ann fn@(EVar _ (Identifier fnVar)) args) = do
let functionName = FunctionName fnVar
isImport <- asks (S.member functionName . aeImportNames)
if isImport
then tell (S.singleton $ CallImportedFunction ann fn)
then tell (S.singleton $ CallImportedFunction ann functionName)
else do
-- whatever abilities this function uses, we now use
functionAbilities <- lookupFunctionAbilities fn
-- if this name points at a function, whatever abilities
-- that function uses, we use
functionAbilities <- lookupFunctionAbilities functionName
tell functionAbilities
EApply ann fn <$> traverse abilityExpr args
EApply ann <$> abilityExpr fn <*> traverse abilityExpr args
abilityExpr other = bindExpr abilityExpr other
10 changes: 5 additions & 5 deletions wasm-calc12/src/Calc/Dependencies.hs
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,11 @@ getFunctionDependencies globalNames importNames (Function {fnFunctionName, fnBod
getExprDependencies :: S.Set Identifier -> S.Set FunctionName -> Expr ann -> S.Set Dependency
getExprDependencies globalNames importNames = snd . runWriter . go
where
go (EApply ann fnName args) = do
if S.member fnName importNames
then tell (S.singleton $ DepImport fnName)
else tell (S.singleton $ DepFunction fnName)
EApply ann fnName <$> traverse go args
go (EApply ann fnExpr@(EVar _ (Identifier fnName)) args) = do
if S.member (FunctionName fnName) importNames
then tell (S.singleton $ DepImport (FunctionName fnName))
else tell (S.singleton $ DepFunction (FunctionName fnName))
EApply ann <$> go fnExpr <*> traverse go args
go (ESet ann globalName value) = do
tell (S.singleton $ DepGlobal globalName)
ESet ann globalName <$> go value
Expand Down
20 changes: 8 additions & 12 deletions wasm-calc12/src/Calc/Linearity/Decorate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@ import Calc.ExprUtils
import Calc.Linearity.Types
import Calc.TypeUtils
import Calc.Types.Expr
import Calc.Types.FunctionName
import Calc.Types.Identifier
import Calc.Types.Pattern
import Calc.Types.Type
import Control.Monad (unless, when)
import Control.Monad (unless)
import Control.Monad.State
import Control.Monad.Writer
import Data.Bifunctor (second)
Expand All @@ -40,7 +39,7 @@ pushUses ::
(MonadState (LinearState ann) m) =>
M.Map Identifier (NE.NonEmpty (Linearity ann)) ->
m ()
pushUses uses =
pushUses uses = do
let pushForIdent ident =
traverse_ (\(Whole ann) -> recordUsesInState ident ann)
in traverse_ (uncurry pushForIdent) (M.toList uses)
Expand All @@ -54,7 +53,7 @@ recordUsesInState ::
Identifier ->
ann ->
m ()
recordUsesInState ident ann =
recordUsesInState ident ann = do
modify
( \ls ->
let f =
Expand All @@ -78,7 +77,9 @@ recordUse ::
m ()
recordUse ident ty = do
recordUsesInState ident (getOuterTypeAnnotation ty)
unless (isPrimitive ty) $ tell (M.singleton ident ty) -- we only want to track use of non-primitive types
ignoreVars <- gets lsIgnoreVars
unless (S.member ident ignoreVars || isPrimitive ty) $
tell (M.singleton ident ty) -- we only want to track use of non-primitive types

-- run an action, giving it a new uses scope
-- then chop off the new values and return them
Expand Down Expand Up @@ -265,13 +266,8 @@ decorate (EIf ty predExpr thenExpr elseExpr) = do
<$> decorate predExpr
<*> pure (mapOuterExprAnnotation (second (const uniqueToElse)) decoratedThen)
<*> pure (mapOuterExprAnnotation (second (const uniqueToThen)) decoratedElse)
decorate (EApply ty fnName@(FunctionName inner) args) = do
-- if we know about the var, assume it's a lambda not a built in function
let identifier = Identifier inner
isVar <- gets (M.member (UserDefined identifier) . lsVars)
when isVar $
recordUse (Identifier inner) ty
EApply (ty, Nothing) fnName <$> traverse decorate args
decorate (EApply ty fn args) = do
EApply (ty, Nothing) <$> decorate fn <*> traverse decorate args
decorate (ETuple ty a as) =
ETuple (ty, Nothing) <$> decorate a <*> traverse decorate as
decorate (EBox ty a) =
Expand Down
4 changes: 3 additions & 1 deletion wasm-calc12/src/Calc/Linearity/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import Calc.Types.Identifier
import Calc.Types.Type
import qualified Data.List.NonEmpty as NE
import qualified Data.Map as M
import qualified Data.Set as S
import GHC.Natural

data Drops ann
Expand All @@ -40,6 +41,7 @@ data UserDefined a = UserDefined a | Internal a
data LinearState ann = LinearState
{ lsVars :: M.Map (UserDefined Identifier) (LinearityType, ann),
lsUses :: NE.NonEmpty (M.Map Identifier (NE.NonEmpty (Linearity ann))),
lsFresh :: Natural
lsFresh :: Natural,
lsIgnoreVars :: S.Set Identifier
}
deriving stock (Eq, Ord, Show, Functor)
9 changes: 6 additions & 3 deletions wasm-calc12/src/Calc/Linearity/Validate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import Data.Foldable (traverse_)
import Data.Functor (($>))
import qualified Data.List.NonEmpty as NE
import qualified Data.Map as M
import qualified Data.Set as S

getLinearityAnnotation :: Linearity ann -> ann
getLinearityAnnotation (Whole ann) = ann
Expand Down Expand Up @@ -74,7 +75,7 @@ getFunctionUses ::
(Show ann) =>
Function (Type ann) ->
(Expr (Type ann, Maybe (Drops ann)), LinearState ann)
getFunctionUses (Function {fnBody, fnArgs}) =
getFunctionUses (Function {fnFunctionName = FunctionName fnName, fnBody, fnArgs}) =
fst $ runIdentity $ runWriterT $ runStateT action initialState
where
action = decorate fnBody
Expand All @@ -83,7 +84,8 @@ getFunctionUses (Function {fnBody, fnArgs}) =
LinearState
{ lsVars = initialVars,
lsUses = NE.singleton mempty,
lsFresh = 0
lsFresh = 0,
lsIgnoreVars = S.singleton (Identifier fnName) -- don't count recursive calls
}

initialVars =
Expand All @@ -108,5 +110,6 @@ getGlobalUses (Global {glbExpr}) =
LinearState
{ lsVars = mempty,
lsUses = NE.singleton mempty,
lsFresh = 0
lsFresh = 0,
lsIgnoreVars = mempty
}
19 changes: 14 additions & 5 deletions wasm-calc12/src/Calc/Parser/Expr.hs
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,22 @@ varParser =
addLocation $
EVar mempty <$> identifierParser

applyFuncParser :: Parser (Expr Annotation)
applyFuncParser = do
varParser <?> "term"

applyParser :: Parser (Expr Annotation)
applyParser = addLocation $ do
fnName <- functionNameParser
stringLiteral "("
args <- sepEndBy exprParserInternal (stringLiteral ",")
stringLiteral ")"
pure (EApply mempty fnName args)
func <- applyFuncParser
let argParser = do
stringLiteral "("
args <- sepEndBy exprParserInternal (stringLiteral ",")
stringLiteral ")"
pure args
let argParser' :: Parser [[ParserExpr]]
argParser' = (: []) <$> argParser
args <- chainl1 argParser' (pure (<>))
pure $ foldl (EApply mempty) func args

tupleParser :: Parser (Expr Annotation)
tupleParser = label "tuple" $
Expand Down
15 changes: 14 additions & 1 deletion wasm-calc12/src/Calc/Parser/Shared.hs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
{-# LANGUAGE OverloadedStrings #-}

module Calc.Parser.Shared
( inBrackets,
( chainl1,
inBrackets,
myLexeme,
withLocation,
stringLiteral,
Expand Down Expand Up @@ -59,3 +60,15 @@ maybePred parser predicate' = try $ do
case predicate' a of
Just b -> pure b
_ -> fail $ T.unpack $ "Predicate did not hold for " <> T.pack (show a)

-- | stolen from Parsec, allows parsing infix expressions without recursion
-- death
chainl1 :: Parser a -> Parser (a -> a -> a) -> Parser a
chainl1 p op = do x <- p; rest x
where
rest x =
do
f <- op
y <- p
rest (f x y)
<|> return x
48 changes: 38 additions & 10 deletions wasm-calc12/src/Calc/Typecheck/Elaborate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import Calc.Types.Test
import Calc.Types.Type
import Control.Monad.State
import Data.Functor
import qualified Data.Map.Strict as M
import qualified Data.Set as S

elaborateModule ::
Expand Down Expand Up @@ -51,6 +52,22 @@ elaborateModule
tceDataTypes = arrangeDataTypes mdDataTypes
}

-- statically provide types of all functions in scope
let functionsInScope =
foldMap
( \(Function {fnFunctionName, fnAnn, fnArgs, fnReturnType}) ->
M.singleton
fnFunctionName
( TFunction fnAnn (faType <$> fnArgs) fnReturnType
)
)
mdFunctions

let importsInScope =
foldMap
(\(Import {impImportName, impAnn, impArgs, impReturnType}) -> M.singleton impImportName (TFunction impAnn (iaType <$> impArgs) impReturnType))
mdImports

runTypecheckM typecheckEnv $ do
globals <-
traverse
Expand All @@ -73,7 +90,7 @@ elaborateModule
functions <-
traverse
( \fn -> do
elabFn <- elaborateFunction fn
elabFn <- elaborateFunction (functionsInScope <> importsInScope) fn
storeFunction
(fnFunctionName elabFn)
(S.fromList $ fnGenerics fn)
Expand All @@ -82,7 +99,7 @@ elaborateModule
)
mdFunctions

tests <- traverse elaborateTest mdTests
tests <- traverse (elaborateTest functionsInScope) mdTests

pure $
Module
Expand All @@ -100,9 +117,14 @@ elaborateDataType (Data dtName vars cons) =

-- check a test expression has type `Bool`
-- later we'll also check it does not use any imports
elaborateTest :: Test ann -> TypecheckM ann (Test (Type ann))
elaborateTest (Test {tesAnn, tesName, tesExpr}) = do
elabExpr <- check (TPrim tesAnn TBool) tesExpr
elaborateTest :: M.Map FunctionName (Type ann) -> Test ann -> TypecheckM ann (Test (Type ann))
elaborateTest functionsInScope (Test {tesAnn, tesName, tesExpr}) = do
elabExpr <-
withFunctionEnv
mempty
functionsInScope
mempty
(check (TPrim tesAnn TBool) tesExpr)

pure $
Test
Expand Down Expand Up @@ -173,9 +195,11 @@ checkAndSubstitute ty expr = do
pure $ substitute unified <$> exprA

elaborateFunction ::
M.Map FunctionName (Type ann) ->
Function ann ->
TypecheckM ann (Function (Type ann))
elaborateFunction
functionsInScope
( Function
{ fnPublic,
fnAnn,
Expand All @@ -187,15 +211,17 @@ elaborateFunction
fnBody
}
) = do
-- store current function so we can recursively call ourselves
storeFunction
fnFunctionName
(S.fromList fnGenerics)
(TFunction fnAnn (faType <$> fnArgs) fnReturnType)
-- include current function with arguments so we can recursively call ourselves
let tyCurrentFunction =
TFunction fnAnn (faType <$> fnArgs) fnReturnType

let functionsWithCurrent =
M.insert fnFunctionName tyCurrentFunction functionsInScope

exprA <-
withFunctionEnv
fnArgs
functionsWithCurrent
(S.fromList fnGenerics)
(checkAndSubstitute fnReturnType fnBody)

Expand All @@ -208,11 +234,13 @@ elaborateFunction
}
)
<$> fnArgs

let tyFn =
TFunction
fnAnn
(faType <$> fnArgs)
(getOuterAnnotation exprA)

pure
( Function
{ fnAnn = tyFn,
Expand Down
19 changes: 15 additions & 4 deletions wasm-calc12/src/Calc/Typecheck/Helpers.hs
Original file line number Diff line number Diff line change
Expand Up @@ -182,18 +182,29 @@ withLambdaEnv args =
-- | temporarily add function arguments and generics into the Reader env
withFunctionEnv ::
[FunctionArg ann] ->
M.Map FunctionName (Type ann) ->
S.Set TypeVar ->
TypecheckM ann a ->
TypecheckM ann a
withFunctionEnv args generics =
let identifiers =
withFunctionEnv args functionsInScope generics =
let identifiersFromArgs =
fmap
(\FunctionArg {faName = ArgumentName arg, faType} -> (Identifier arg, faType))
( \FunctionArg {faName = ArgumentName arg, faType} ->
(Identifier arg, faType)
)
args
identifiersFromFunctions =
( \(FunctionName fnName, fnType) ->
(Identifier fnName, fnType)
)
<$> M.toList functionsInScope
in local
( \tce ->
tce
{ tceVars = tceVars tce <> HM.fromList identifiers,
{ tceVars =
tceVars tce
<> HM.fromList identifiersFromFunctions
<> HM.fromList identifiersFromArgs,
tceGenerics = generics
}
)
Expand Down
Loading

0 comments on commit 03f34ee

Please sign in to comment.