Skip to content

Commit

Permalink
Nice
Browse files Browse the repository at this point in the history
  • Loading branch information
danieljharvey committed Dec 28, 2023
1 parent e673bea commit 015aa5e
Show file tree
Hide file tree
Showing 8 changed files with 182 additions and 88 deletions.
19 changes: 17 additions & 2 deletions wasm-calc5/src/Calc/TypeUtils.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
module Calc.TypeUtils (bindType, mapOuterTypeAnnotation, getOuterTypeAnnotation) where
module Calc.TypeUtils
( bindType,
mapType,
mapOuterTypeAnnotation,
getOuterTypeAnnotation,
)
where

import Calc.Types.Type
import Control.Monad.Identity

getOuterTypeAnnotation :: Type ann -> ann
getOuterTypeAnnotation (TPrim ann _) = ann
Expand All @@ -16,7 +23,15 @@ mapOuterTypeAnnotation f (TTuple ann a b) = TTuple (f ann) a b
mapOuterTypeAnnotation f (TVar ann v) = TVar (f ann) v
mapOuterTypeAnnotation f (TUnificationVar ann v) = TUnificationVar (f ann) v

bindType :: (Applicative m) => (Type ann -> m (Type ann)) -> Type ann -> m (Type ann)
mapType :: (Type ann -> Type ann) -> Type ann -> Type ann
mapType f ty =
runIdentity (bindType (pure . f) ty)

bindType ::
(Applicative m) =>
(Type ann -> m (Type ann)) ->
Type ann ->
m (Type ann)
bindType _ (TPrim ann p) =
pure $ TPrim ann p
bindType f (TFunction ann a b) =
Expand Down
79 changes: 50 additions & 29 deletions wasm-calc5/src/Calc/Typecheck/Elaborate.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Calc.Typecheck.Elaborate
Expand All @@ -9,23 +9,25 @@ module Calc.Typecheck.Elaborate
)
where

import Calc.ExprUtils
import Calc.TypeUtils
import Calc.Typecheck.Error
import Calc.Typecheck.Helpers
import Calc.Typecheck.Types
import Calc.Types.Expr
import Calc.Types.Function
import Calc.Types.Module
import Calc.Types.Prim
import Calc.Types.Type
import Control.Monad (when, zipWithM)
import Control.Monad.Except
import Data.Bifunctor (second)
import Data.Functor
import qualified Data.List as List
import qualified Data.List.NonEmpty as NE
import qualified Data.Set as S
import Calc.ExprUtils
import Calc.Typecheck.Error
import Calc.Typecheck.Helpers
import Calc.Typecheck.Substitute
import Calc.Typecheck.Types
import Calc.Types.Expr
import Calc.Types.Function
import Calc.Types.Module
import Calc.Types.Prim
import Calc.Types.Type
import Calc.TypeUtils
import Control.Monad (when, zipWithM)
import Control.Monad.Except
import Control.Monad.State
import Data.Bifunctor (second)
import Data.Functor
import qualified Data.List as List
import qualified Data.List.NonEmpty as NE
import qualified Data.Set as S

elaborateModule ::
forall ann.
Expand All @@ -42,16 +44,32 @@ elaborateModule (Module {mdFunctions, mdExpr}) =
)
mdFunctions

Module fns <$> infer mdExpr
Module fns <$> inferAndSubstitute mdExpr

inferAndSubstitute ::
Expr ann ->
TypecheckM ann (Expr (Type ann))
inferAndSubstitute expr = do
exprA <- infer expr
unified <- gets tcsUnified
pure $ substitute unified <$> exprA

elaborateFunction ::
Function ann ->
TypecheckM ann (Function (Type ann))
elaborateFunction (Function {fnAnn, fnArgs, fnGenerics, fnFunctionName, fnBody}) = do
exprA <- withFunctionEnv fnArgs (S.fromList fnGenerics) (infer fnBody)
_ <- error "we should now substitute from the state to exprA to fill it what we learned"
let argsA = fmap (second (\ty -> fmap (const ty) ty)) fnArgs
let tyFn = TFunction fnAnn (snd <$> fnArgs) (getOuterAnnotation exprA)
exprA <-
withFunctionEnv
fnArgs
(S.fromList fnGenerics)
(inferAndSubstitute fnBody)
let argsA =
second (\ty -> fmap (const ty) ty) <$> fnArgs
let tyFn =
TFunction
fnAnn
(snd <$> fnArgs)
(getOuterAnnotation exprA)
pure
( Function
{ fnAnn = tyFn,
Expand All @@ -76,7 +94,9 @@ unify (TUnificationVar _ nat) b = do
unify a (TUnificationVar _ nat) = do
unifyVariableWithType nat a
unify (TFunction ann argA bodyA) (TFunction _ argB bodyB) =
TFunction ann <$> zipWithM unify argA argB <*> unify bodyA bodyB
TFunction ann
<$> zipWithM unify argA argB
<*> unify bodyA bodyB
unify (TTuple ann a as) (TTuple _ b bs) =
TTuple ann
<$> unify a b
Expand All @@ -96,7 +116,7 @@ inferIf ann predExpr thenExpr elseExpr = do
predA <- infer predExpr
case getOuterAnnotation predA of
(TPrim _ TBool) -> pure ()
otherType -> throwError (PredicateIsNotBoolean ann otherType)
otherType -> throwError (PredicateIsNotBoolean ann otherType)
thenA <- infer thenExpr
elseA <- check (getOuterAnnotation thenA) elseExpr
pure (EIf (getOuterAnnotation elseA) predA thenA elseA)
Expand Down Expand Up @@ -157,7 +177,8 @@ inferApply ann fnName args = do
(length args /= length tArgs)
(throwError $ FunctionArgumentLengthMismatch ann (length tArgs) (length args))
elabArgs <- zipWithM check tArgs args -- check each arg against type
pure (tReturn, elabArgs)
unified <- gets tcsUnified
pure (substitute unified tReturn, elabArgs)
_ -> throwError $ NonFunctionTypeFound ann fn
pure (EApply (ty $> ann) fnName elabArgs)

Expand Down Expand Up @@ -194,8 +215,8 @@ infer (EInfix ann op a b) =
inferInfix ann op a b

typePrimFromPrim :: Prim -> TypePrim
typePrimFromPrim (PInt _) = TInt
typePrimFromPrim (PBool _) = TBool
typePrimFromPrim (PInt _) = TInt
typePrimFromPrim (PBool _) = TBool
typePrimFromPrim (PFloat _) = TFloat

typeFromPrim :: ann -> Prim -> Type ann
Expand Down
46 changes: 32 additions & 14 deletions wasm-calc5/src/Calc/Typecheck/Generalise.hs
Original file line number Diff line number Diff line change
@@ -1,24 +1,42 @@
module Calc.Typecheck.Generalise (generalise) where

import Calc.TypeUtils (bindType)
import Calc.Typecheck.Types
import Calc.Types.Type
import Calc.Types.TypeVar
import Control.Monad.State
import qualified Data.Set as S
import GHC.Natural

import Calc.Typecheck.Types
import Calc.Types.Type
import Calc.Types.TypeVar
import Calc.TypeUtils (mapType)
import Control.Monad.State
import qualified Data.HashMap.Strict as HM
import qualified Data.Set as S
import GHC.Natural

-- get a nice new number
freshUnificationVariable :: TypecheckM ann Natural
freshUnificationVariable = do
current <- gets tcsUnique
modify (\tcs -> tcs {tcsUnique = current + 1})
pure current
modify (\tcs -> tcs {tcsUnique = tcsUnique tcs + 1})
gets tcsUnique

allFresh :: S.Set TypeVar -> TypecheckM ann (HM.HashMap TypeVar Natural)
allFresh generics =
let freshOne typeVar =
HM.singleton typeVar <$> freshUnificationVariable
in mconcat <$> traverse freshOne (S.toList generics)

-- given a type, replace anything that should be generic with unification
-- variables so that we know to replace them with types easily
generalise :: S.Set TypeVar -> Type ann -> TypecheckM ann (Type ann)
generalise generics (TVar ann var)
| S.member var generics =
TUnificationVar ann <$> freshUnificationVariable
generalise generics other = bindType (generalise generics) other
generalise generics ty
= do
fresh <- allFresh generics
pure $ generaliseInternal fresh ty

-- given a type, replace anything that should be generic with unification
-- variables so that we know to replace them with types easily
generaliseInternal :: HM.HashMap TypeVar Natural -> Type ann -> Type ann
generaliseInternal fresh (TVar ann var) =
case HM.lookup var fresh of
Just nat ->
TUnificationVar ann nat
Nothing -> error "oh no generalise error"
generaliseInternal fresh other =
mapType (generaliseInternal fresh) other
61 changes: 38 additions & 23 deletions wasm-calc5/src/Calc/Typecheck/Helpers.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE NamedFieldPuns #-}

module Calc.Typecheck.Helpers
( runTypecheckM,
Expand All @@ -11,21 +11,20 @@ module Calc.Typecheck.Helpers
storeFunction,
)
where

import Calc.Typecheck.Error
import Calc.Typecheck.Generalise
import Calc.Typecheck.Types
import Calc.Types.Function
import Calc.Types.Identifier
import Calc.Types.Type
import Calc.Types.TypeVar
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor (first)
import qualified Data.HashMap.Strict as HM
import qualified Data.Set as S
import GHC.Natural
import Calc.Typecheck.Error
import Calc.Typecheck.Generalise
import Calc.Typecheck.Types
import Calc.Types.Function
import Calc.Types.Identifier
import Calc.Types.Type
import Calc.Types.TypeVar
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor (first)
import qualified Data.HashMap.Strict as HM
import qualified Data.Set as S
import GHC.Natural

runTypecheckM ::
TypecheckEnv ann ->
Expand All @@ -34,7 +33,12 @@ runTypecheckM ::
runTypecheckM env action =
evalStateT
(runReaderT (getTypecheckM action) env)
(TypecheckState {tcsFunctions = mempty, tcsUnique = 0, tcsUnified = mempty})
( TypecheckState
{ tcsFunctions = mempty,
tcsUnique = 0,
tcsUnified = mempty
}
)

storeFunction ::
FunctionName ->
Expand All @@ -46,16 +50,21 @@ storeFunction fnName generics ty =
( \tcs ->
tcs
{ tcsFunctions =
HM.insert fnName (TypeScheme ty generics) (tcsFunctions tcs)
HM.insert
fnName
(TypeScheme ty generics)
(tcsFunctions tcs)
}
)

-- | look up a saved identifier "in the environment"
lookupFunction :: ann -> FunctionName -> TypecheckM ann (Type ann)
lookupFunction :: ann -> FunctionName -> TypecheckM ann (Type ann)
lookupFunction ann fnName = do
maybeType <- gets (HM.lookup fnName . tcsFunctions)
maybeType <- gets (HM.lookup fnName . tcsFunctions)

case maybeType of
Just (TypeScheme {tsType, tsGenerics}) -> generalise tsGenerics tsType
Just (TypeScheme {tsType, tsGenerics}) ->
generalise tsGenerics tsType
Nothing -> do
allFunctions <- gets (HM.keysSet . tcsFunctions)
throwError (FunctionNotFound ann fnName allFunctions)
Expand Down Expand Up @@ -99,7 +108,10 @@ withFunctionEnv args generics =

-- | given a unification variable, either save it and return the type
-- or explode because we've already unified it with something else
unifyVariableWithType :: Natural -> Type ann -> TypecheckM ann (Type ann)
unifyVariableWithType ::
Natural ->
Type ann ->
TypecheckM ann (Type ann)
unifyVariableWithType nat ty =
do
existing <- gets (HM.lookup nat . tcsUnified)
Expand All @@ -108,7 +120,10 @@ unifyVariableWithType nat ty =
-- this is the first match, store it and return the passed-in type
modify
( \tcs ->
tcs {tcsUnified = tcsUnified tcs <> HM.singleton nat ty}
tcs
{ tcsUnified =
HM.insert nat ty (tcsUnified tcs)
}
)
pure ty
Just _existingTy -> do
Expand Down
15 changes: 15 additions & 0 deletions wasm-calc5/src/Calc/Typecheck/Substitute.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
module Calc.Typecheck.Substitute (substitute) where

import Calc.Types.Type
import Calc.TypeUtils
import qualified Data.HashMap.Strict as HM
import GHC.Natural

substitute :: HM.HashMap Natural (Type ann) ->
Type ann -> Type ann
substitute subs (TUnificationVar _ nat) =
case HM.lookup nat subs of
Just ty -> ty
Nothing -> error $ "Could not find unification var for " <> show nat
substitute subs other
= mapType (substitute subs) other
5 changes: 4 additions & 1 deletion wasm-calc5/src/Calc/Typecheck/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ data TypecheckState ann = TypecheckState
}
deriving stock (Eq, Ord, Show)

data TypeScheme ann = TypeScheme {tsType :: Type ann, tsGenerics :: S.Set TypeVar}
data TypeScheme ann = TypeScheme
{ tsType :: Type ann,
tsGenerics :: S.Set TypeVar
}
deriving stock (Eq, Ord, Show)

newtype TypecheckM ann a = TypecheckM
Expand Down
Loading

0 comments on commit 015aa5e

Please sign in to comment.