Skip to content

Commit

Permalink
Unbox operator
Browse files Browse the repository at this point in the history
  • Loading branch information
danieljharvey committed Dec 31, 2023
1 parent e1857bc commit 6d74c5b
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 131 deletions.
38 changes: 19 additions & 19 deletions wasm-calc5/src/Calc/ExprUtils.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,33 @@ module Calc.ExprUtils
)
where

import Calc.Types
import Calc.Types

-- | get the annotation in the first leaf found in an `Expr`.
-- useful for getting the overall type of an expression
getOuterAnnotation :: Expr ann -> ann
getOuterAnnotation (EInfix ann _ _ _) = ann
getOuterAnnotation (EPrim ann _) = ann
getOuterAnnotation (EIf ann _ _ _) = ann
getOuterAnnotation (EVar ann _) = ann
getOuterAnnotation (EApply ann _ _) = ann
getOuterAnnotation (ETuple ann _ _) = ann
getOuterAnnotation (ETupleAccess ann _ _) = ann
getOuterAnnotation (EBox ann _) = ann
getOuterAnnotation (EInfix ann _ _ _) = ann
getOuterAnnotation (EPrim ann _) = ann
getOuterAnnotation (EIf ann _ _ _) = ann
getOuterAnnotation (EVar ann _) = ann
getOuterAnnotation (EApply ann _ _) = ann
getOuterAnnotation (ETuple ann _ _) = ann
getOuterAnnotation (EContainerAccess ann _ _) = ann
getOuterAnnotation (EBox ann _) = ann

-- | modify the outer annotation of an expression
-- useful for adding line numbers during parsing
mapOuterExprAnnotation :: (ann -> ann) -> Expr ann -> Expr ann
mapOuterExprAnnotation f expr' =
case expr' of
EInfix ann a b c -> EInfix (f ann) a b c
EPrim ann a -> EPrim (f ann) a
EIf ann a b c -> EIf (f ann) a b c
EVar ann a -> EVar (f ann) a
EApply ann a b -> EApply (f ann) a b
ETuple ann a b -> ETuple (f ann) a b
ETupleAccess ann a b -> ETupleAccess (f ann) a b
EBox ann a -> EBox (f ann) a
EInfix ann a b c -> EInfix (f ann) a b c
EPrim ann a -> EPrim (f ann) a
EIf ann a b c -> EIf (f ann) a b c
EVar ann a -> EVar (f ann) a
EApply ann a b -> EApply (f ann) a b
ETuple ann a b -> ETuple (f ann) a b
EContainerAccess ann a b -> EContainerAccess (f ann) a b
EBox ann a -> EBox (f ann) a

-- | Given a function that changes `Expr` values to `m Expr`, apply it throughout
-- an AST tree
Expand All @@ -50,6 +50,6 @@ bindExpr f (EIf ann predExpr thenExpr elseExpr) =
EIf ann <$> f predExpr <*> f thenExpr <*> f elseExpr
bindExpr f (ETuple ann a as) =
ETuple ann <$> f a <*> traverse f as
bindExpr f (ETupleAccess ann a nat) =
ETupleAccess ann <$> f a <*> pure nat
bindExpr f (EContainerAccess ann a nat) =
EContainerAccess ann <$> f a <*> pure nat
bindExpr f (EBox ann a) = EBox ann <$> f a
32 changes: 16 additions & 16 deletions wasm-calc5/src/Calc/Interpreter.hs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralisedNewtypeDeriving #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE NamedFieldPuns #-}

module Calc.Interpreter
( runInterpreter,
Expand All @@ -13,15 +13,15 @@ module Calc.Interpreter
)
where

import Calc.Types
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Data.Coerce
import qualified Data.List.NonEmpty as NE
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import GHC.Natural
import Calc.Types
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Data.Coerce
import qualified Data.List.NonEmpty as NE
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import GHC.Natural

-- | type for interpreter state
newtype InterpreterState ann = InterpreterState
Expand Down Expand Up @@ -142,15 +142,15 @@ interpret (ETuple ann a as) = do
aA <- interpret a
asA <- traverse interpret as
pure (ETuple ann aA asA)
interpret (ETupleAccess _ tup index) = do
interpret (EContainerAccess _ tup index) = do
aTup <- interpret tup
interpretTupleAccess aTup index
interpret (EIf ann predExpr thenExpr elseExpr) = do
predA <- interpret predExpr
case predA of
(EPrim _ (PBool True)) -> interpret thenExpr
(EPrim _ (PBool True)) -> interpret thenExpr
(EPrim _ (PBool False)) -> interpret elseExpr
other -> throwError (NonBooleanPredicate ann other)
other -> throwError (NonBooleanPredicate ann other)
interpret (EBox ann a) =
EBox ann <$> interpret a

Expand All @@ -159,7 +159,7 @@ interpretTupleAccess wholeExpr@(ETuple _ fstExpr restExpr) index = do
let items = zip ([0 ..] :: [Natural]) (fstExpr : NE.toList restExpr)
case lookup (index - 1) items of
Just expr -> pure expr
Nothing -> throwError (AccessOutsideTupleBounds wholeExpr index)
Nothing -> throwError (AccessOutsideTupleBounds wholeExpr index)
interpretTupleAccess wholeExpr@(EBox _ innerExpr) index = do
case index of
1 -> interpret innerExpr
Expand Down
54 changes: 36 additions & 18 deletions wasm-calc5/src/Calc/Parser/Expr.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,26 @@

module Calc.Parser.Expr (exprParser) where

import Calc.Parser.Identifier
import Calc.Parser.Primitives
import Calc.Parser.Shared
import Calc.Parser.Types
import Calc.Types.Annotation
import Calc.Types.Expr
import Control.Monad.Combinators.Expr
import Data.Foldable (foldl')
import qualified Data.List.NonEmpty as NE
import qualified Data.Text as T
import GHC.Natural
import Text.Megaparsec
import Calc.Parser.Identifier
import Calc.Parser.Primitives
import Calc.Parser.Shared
import Calc.Parser.Types
import Calc.Types.Annotation
import Calc.Types.Expr
import Control.Monad.Combinators.Expr
import Data.Foldable (foldl')
import qualified Data.List.NonEmpty as NE
import qualified Data.Text as T
import GHC.Natural
import Text.Megaparsec

exprParser :: Parser (Expr Annotation)
exprParser = addLocation (makeExprParser exprPart table) <?> "expression"

exprPart :: Parser (Expr Annotation)
exprPart =
try tupleAccessParser
try unboxParser
<|> try containerAccessParser
<|> try tupleParser
<|> boxParser
<|> inBrackets (addLocation exprParser)
Expand Down Expand Up @@ -69,12 +70,29 @@ tupleParser = label "tuple" $
neArgs <- NE.fromList <$> sepBy1 exprParser (stringLiteral ",")
neTail <- case NE.nonEmpty (NE.tail neArgs) of
Just ne -> pure ne
_ -> fail "Expected at least two items in a tuple"
_ -> fail "Expected at least two items in a tuple"
_ <- stringLiteral ")"
pure (ETuple mempty (NE.head neArgs) neTail)

tupleAccessParser :: Parser (Expr Annotation)
tupleAccessParser =
unboxParser :: Parser (Expr Annotation)
unboxParser =
let tupParser :: Parser (Expr Annotation)
tupParser =
try containerAccessParser <|>
try tupleParser
<|> try applyParser
<|> try varParser
<|> boxParser
in label "unbox" $
addLocation $ do
tup <- tupParser
_ <- stringLiteral "!"
pure $
EContainerAccess mempty tup 1


containerAccessParser :: Parser (Expr Annotation)
containerAccessParser =
let natParser :: Parser Natural
natParser = myLexeme (fromIntegral <$> intParser)

Expand All @@ -84,14 +102,14 @@ tupleAccessParser =
<|> try applyParser
<|> try varParser
<|> boxParser
in label "tuple access" $
in label "container access" $
addLocation $ do
tup <- tupParser
_ <- stringLiteral "."
accesses <- sepBy1 natParser (stringLiteral ".")
pure $
foldl'
( ETupleAccess mempty
( EContainerAccess mempty
)
tup
accesses
Expand Down
54 changes: 27 additions & 27 deletions wasm-calc5/src/Calc/Typecheck/Elaborate.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Calc.Typecheck.Elaborate
Expand All @@ -9,25 +9,25 @@ module Calc.Typecheck.Elaborate
)
where

import Calc.ExprUtils
import Calc.TypeUtils
import Calc.Typecheck.Error
import Calc.Typecheck.Helpers
import Calc.Typecheck.Substitute
import Calc.Typecheck.Types
import Calc.Types.Expr
import Calc.Types.Function
import Calc.Types.Module
import Calc.Types.Prim
import Calc.Types.Type
import Control.Monad (when, zipWithM)
import Control.Monad.Except
import Control.Monad.State
import Data.Bifunctor (second)
import Data.Functor
import qualified Data.List as List
import qualified Data.List.NonEmpty as NE
import qualified Data.Set as S
import Calc.ExprUtils
import Calc.Typecheck.Error
import Calc.Typecheck.Helpers
import Calc.Typecheck.Substitute
import Calc.Typecheck.Types
import Calc.Types.Expr
import Calc.Types.Function
import Calc.Types.Module
import Calc.Types.Prim
import Calc.Types.Type
import Calc.TypeUtils
import Control.Monad (when, zipWithM)
import Control.Monad.Except
import Control.Monad.State
import Data.Bifunctor (second)
import Data.Functor
import qualified Data.List as List
import qualified Data.List.NonEmpty as NE
import qualified Data.Set as S

elaborateModule ::
forall ann.
Expand Down Expand Up @@ -115,7 +115,7 @@ inferIf ann predExpr thenExpr elseExpr = do
predA <- infer predExpr
case getOuterAnnotation predA of
(TPrim _ TBool) -> pure ()
otherType -> throwError (PredicateIsNotBoolean ann otherType)
otherType -> throwError (PredicateIsNotBoolean ann otherType)
thenA <- infer thenExpr
elseA <- check (getOuterAnnotation thenA) elseExpr
pure (EIf (getOuterAnnotation elseA) predA thenA elseA)
Expand Down Expand Up @@ -170,7 +170,7 @@ checkApplyArg ty@(TUnificationVar {}) expr = do
tyExpr <- infer expr
case getOuterAnnotation tyExpr of
p@TPrim {} -> throwError (NonBoxedGenericValue (getOuterTypeAnnotation p) p)
_other -> check ty expr
_other -> check ty expr
checkApplyArg ty expr = check ty expr

-- | if our return type is polymorphic, our concrete type should not be a
Expand Down Expand Up @@ -233,14 +233,14 @@ infer (ETuple ann fstExpr restExpr) = do
(getOuterAnnotation <$> typedRest)
)
pure $ ETuple typ typedFst typedRest
infer (ETupleAccess ann tup index) = do
infer (EContainerAccess ann tup index) = do
tyTup <- infer tup
case getOuterAnnotation tyTup of
TContainer _ tyAll ->
let tyNumbered = zip ([0 ..] :: [Int]) (NE.toList tyAll)
in case List.lookup (fromIntegral $ index - 1) tyNumbered of
Just ty ->
pure (ETupleAccess ty tyTup index)
pure (EContainerAccess ty tyTup index)
Nothing -> throwError $ AccessingOutsideTupleBounds ann (getOuterAnnotation tyTup) index
otherTy -> throwError $ AccessingNonTuple ann otherTy
infer (EApply ann fnName args) =
Expand All @@ -252,8 +252,8 @@ infer (EInfix ann op a b) =
inferInfix ann op a b

typePrimFromPrim :: Prim -> TypePrim
typePrimFromPrim (PInt _) = TInt
typePrimFromPrim (PBool _) = TBool
typePrimFromPrim (PInt _) = TInt
typePrimFromPrim (PBool _) = TBool
typePrimFromPrim (PFloat _) = TFloat

typeFromPrim :: ann -> Prim -> Type ann
Expand Down
26 changes: 13 additions & 13 deletions wasm-calc5/src/Calc/Types/Expr.hs
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE OverloadedStrings #-}

module Calc.Types.Expr (Expr (..), Op (..)) where

import Calc.Types.FunctionName
import Calc.Types.Identifier
import Calc.Types.Prim
import qualified Data.List.NonEmpty as NE
import GHC.Natural
import Prettyprinter ((<+>))
import qualified Prettyprinter as PP
import Calc.Types.FunctionName
import Calc.Types.Identifier
import Calc.Types.Prim
import qualified Data.List.NonEmpty as NE
import GHC.Natural
import Prettyprinter ((<+>))
import qualified Prettyprinter as PP

data Expr ann
= EPrim ann Prim
Expand All @@ -19,7 +19,7 @@ data Expr ann
| EVar ann Identifier
| EApply ann FunctionName [Expr ann]
| ETuple ann (Expr ann) (NE.NonEmpty (Expr ann))
| ETupleAccess ann (Expr ann) Natural
| EContainerAccess ann (Expr ann) Natural
| EBox ann (Expr ann)
deriving stock (Eq, Ord, Show, Functor, Foldable, Traversable)

Expand All @@ -45,7 +45,7 @@ instance PP.Pretty (Expr ann) where
where
tupleItems :: a -> NE.NonEmpty a -> [a]
tupleItems b bs = b : NE.toList bs
pretty (ETupleAccess _ tup nat) =
pretty (EContainerAccess _ tup nat) =
PP.pretty tup <> "." <> PP.pretty nat
pretty (EBox _ inner) =
"Box(" <> PP.pretty inner <> ")"
Expand All @@ -59,7 +59,7 @@ data Op

-- how to print `Op` values
instance PP.Pretty Op where
pretty OpAdd = "+"
pretty OpAdd = "+"
pretty OpMultiply = "*"
pretty OpSubtract = "-"
pretty OpEquals = "=="
pretty OpEquals = "=="
Loading

0 comments on commit 6d74c5b

Please sign in to comment.