Skip to content

Commit

Permalink
Some sort of fix
Browse files Browse the repository at this point in the history
  • Loading branch information
danieljharvey committed Jun 2, 2024
1 parent 07129e1 commit 6b33be5
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 29 deletions.
22 changes: 14 additions & 8 deletions wasm-calc9/src/Calc/Linearity/Types.hs
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}

module Calc.Linearity.Types
( Linearity (..),
LinearityType (..),
LinearState (..),
UserDefined (..)
)
where

import Calc.Types.Identifier
import qualified Data.Map as M
import GHC.Natural
import Calc.Types.Identifier
import qualified Data.Map as M
import GHC.Natural

-- | Are we using the whole type or bits of it?
-- this distinction will be gone once we can destructure types instead,
Expand All @@ -26,9 +27,14 @@ newtype Linearity ann
data LinearityType = LTPrimitive | LTBoxed
deriving stock (Eq, Ord, Show)

-- | differentiate between names provided by a user, and variables
-- created during linearity check to allow us to drop unnamed items
data UserDefined a = UserDefined a | Internal a
deriving stock (Eq,Ord,Show,Functor)

data LinearState ann = LinearState
{ lsVars :: M.Map Identifier (LinearityType, ann),
lsUses :: [(Identifier, Linearity ann)],
{ lsVars :: M.Map (UserDefined Identifier) (LinearityType, ann),
lsUses :: [(Identifier, Linearity ann)],
lsFresh :: Natural
}
deriving stock (Eq, Ord, Show, Functor)
39 changes: 27 additions & 12 deletions wasm-calc9/src/Calc/Linearity/Validate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,17 @@ validateFunction fn =

validate :: LinearState ann -> Either (LinearityError ann) ()
validate (LinearState {lsVars, lsUses}) =
let validateFunctionItem (ident, (linearity, _ann)) =
let validateFunctionItem (Internal _,_) = Right ()
validateFunctionItem (UserDefined ident, (linearity, ann)) =
let completeUses = filterCompleteUses lsUses ident
in case linearity of
LTPrimitive -> Right ()
{- if null completeUses
LTPrimitive ->
if null completeUses
then Left (NotUsed ann ident)
else Right () -}
else Right ()
LTBoxed ->
case length completeUses of
0 -> Right ()
-- Left (NotUsed ann ident)
0 -> Left (NotUsed ann ident)
1 -> Right ()
_more ->
Left (UsedMultipleTimes (getLinearityAnnotation <$> completeUses) ident)
Expand Down Expand Up @@ -122,7 +122,7 @@ getFunctionUses (Function {fnBody, fnArgs}) =
initialVars =
foldMap
( \(FunctionArg {faAnn, faName = ArgumentName arg, faType}) ->
M.singleton (Identifier arg) $ case faType of
M.singleton (UserDefined (Identifier arg)) $ case faType of
TPrim {} -> (LTPrimitive, getOuterTypeAnnotation faAnn)
_ -> (LTBoxed, getOuterTypeAnnotation faAnn)
)
Expand Down Expand Up @@ -169,7 +169,7 @@ addLetBinding (PVar ty ident) = do
ls
{ lsVars =
M.insert
ident
(UserDefined ident)
( if isPrimitive ty then LTPrimitive else LTBoxed,
getOuterTypeAnnotation ty
)
Expand All @@ -179,15 +179,30 @@ addLetBinding (PVar ty ident) = do
pure $ PVar (ty, Nothing) ident
addLetBinding (PWildcard ty) = do
i <- getFresh
let name = Identifier $ "_fresh_name" <> T.pack (show i)
addLetBinding $ PVar ty name
let ident = Identifier $ "_fresh_name" <> T.pack (show i)
modify
( \ls ->
ls
{ lsVars =
M.insert
(Internal ident)
( if isPrimitive ty then LTPrimitive else LTBoxed,
getOuterTypeAnnotation ty
)
(lsVars ls)
}
)
pure $ PVar (ty, dropForType ty) ident
addLetBinding (PBox ty pat) =
PBox (ty, Just DropMe) <$> addLetBinding pat
PBox (ty, dropForType ty) <$> addLetBinding pat
addLetBinding (PTuple ty p ps) = do
PTuple (ty, Just DropMe)
PTuple (ty, dropForType ty )
<$> addLetBinding p
<*> traverse addLetBinding ps

dropForType :: Type ann -> Maybe (Drops an)
dropForType ty = if isPrimitive ty then Nothing else Just DropMe

decorate ::
(Show ann) =>
(MonadState (LinearState ann) m, MonadWriter (M.Map Identifier (Type ann)) m) =>
Expand Down
18 changes: 9 additions & 9 deletions wasm-calc9/test/Test/Linearity/LinearitySpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ spec = do
)
(dVar "a")
),
("function allocUnused() -> Int64 { let _ = Box((1: Int32)); 22 }",
ELet Nothing (PVar (Just DropMe) "_fresh_name1") (EBox Nothing (EAnn Nothing dTyInt32 (dInt 1))) (dInt 22)),

( "function incrementallyDropBoxesAfterUse() -> Int64 { let Box(outer) = Box(Box((100: Int64))); let Box(inner) = outer; inner }",
ELet
Nothing
Expand Down Expand Up @@ -152,28 +155,28 @@ spec = do
[ ( "function sum (a: Int64, b: Int64) -> Int64 { a + b }",
LinearState
{ lsVars =
M.fromList [("a", (LTPrimitive, ())), ("b", (LTPrimitive, ()))],
M.fromList [(UserDefined "a", (LTPrimitive, ())), (UserDefined "b", (LTPrimitive, ()))],
lsUses = [("b", Whole ()), ("a", Whole ())],
lsFresh = 0
}
),
( "function pair<a,b>(a: a, b: b) -> (a,b) { (a,b) }",
LinearState
{ lsVars = M.fromList [("a", (LTBoxed, ())), ("b", (LTBoxed, ()))],
{ lsVars = M.fromList [(UserDefined "a", (LTBoxed, ())), (UserDefined "b", (LTBoxed, ()))],
lsUses = [("b", Whole ()), ("a", Whole ())],
lsFresh = 0
}
),
( "function dontUseA<a,b>(a: a, b: b) -> b { b }",
LinearState
{ lsVars = M.fromList [("a", (LTBoxed, ())), ("b", (LTBoxed, ()))],
{ lsVars = M.fromList [(UserDefined "a", (LTBoxed, ())), (UserDefined "b", (LTBoxed, ()))],
lsUses = [("b", Whole ())],
lsFresh = 0
}
),
( "function dup<a>(a: a) -> (a,a) { (a,a)}",
LinearState
{ lsVars = M.fromList [("a", (LTBoxed, ()))],
{ lsVars = M.fromList [(UserDefined "a", (LTBoxed, ()))],
lsUses = [("a", Whole ()), ("a", Whole ())],
lsFresh = 0
}
Expand Down Expand Up @@ -214,18 +217,15 @@ spec = do

describe "expected failures" $ do
let failures =
[ {-( "function dontUseA<a,b>(a: a, b: b) -> b { b }",
[ ( "function dontUseA<a,b>(a: a, b: b) -> b { b }",
NotUsed () "a"
),
( "function dontUsePrimA(a: Int64, b: Int64) -> Int64 { b }",
NotUsed () "a"
),-}
),
( "function dup<a>(a: a) -> (a,a) { (a,a)}",
UsedMultipleTimes [(), ()] "a"
),
{-( "function twice(pair: (Int64, Int64)) { pair.1 + pair.2 }",
UsedMultipleTimes "pair"
),-}
( "function withPair<a,b>(pair: (a,b)) -> (a,a,b) { let (a,b) = pair; (a, a, b) }",
UsedMultipleTimes [(), ()] "a"
)
Expand Down

0 comments on commit 6b33be5

Please sign in to comment.