diff --git a/wasm-calc11/src/Calc/Linearity/Decorate.hs b/wasm-calc11/src/Calc/Linearity/Decorate.hs
index 23c2a049..bfd3ce5b 100644
--- a/wasm-calc11/src/Calc/Linearity/Decorate.hs
+++ b/wasm-calc11/src/Calc/Linearity/Decorate.hs
@@ -20,6 +20,7 @@ import Control.Monad (unless)
import Control.Monad.State
import Control.Monad.Writer
import Data.Bifunctor (second)
+import Data.Foldable (traverse_)
import qualified Data.List.NonEmpty as NE
import qualified Data.Map as M
import Data.Maybe (mapMaybe)
@@ -32,6 +33,40 @@ getFresh = do
modify (\ls -> ls {lsFresh = lsFresh ls + 1})
gets lsFresh
+-- | push a load of uses directly onto the head of the uses stack
+pushUses ::
+ (MonadState (LinearState ann) m) =>
+ M.Map Identifier (NE.NonEmpty (Linearity ann)) ->
+ m ()
+pushUses uses =
+ let pushForIdent ident =
+ traverse_ (\(Whole ann) -> recordUsesInState ident ann)
+ in traverse_ (uncurry pushForIdent) (M.toList uses)
+
+mapHead :: (a -> a) -> NE.NonEmpty a -> NE.NonEmpty a
+mapHead f (neHead NE.:| neTail) =
+ f neHead NE.:| neTail
+
+recordUsesInState ::
+ (MonadState (LinearState ann) m) =>
+ Identifier ->
+ ann ->
+ m ()
+recordUsesInState ident ann =
+ modify
+ ( \ls ->
+ let f =
+ M.alter
+ ( \existing ->
+ let newItem = Whole ann
+ in Just $ case existing of
+ Just neExisting -> newItem NE.:| NE.toList neExisting
+ Nothing -> NE.singleton newItem
+ )
+ ident
+ in ls {lsUses = mapHead f (lsUses ls)}
+ )
+
recordUse ::
( MonadState (LinearState ann) m,
MonadWriter (M.Map Identifier (Type ann)) m
@@ -40,9 +75,25 @@ recordUse ::
Type ann ->
m ()
recordUse ident ty = do
- modify (\ls -> ls {lsUses = (ident, Whole (getOuterTypeAnnotation ty)) : lsUses ls})
+ recordUsesInState ident (getOuterTypeAnnotation ty)
unless (isPrimitive ty) $ tell (M.singleton ident ty) -- we only want to track use of non-primitive types
+-- run an action, giving it a new uses scope
+-- then chop off the new values and return them
+-- this allows us to dedupe and re-add them to the current stack as desired
+scoped :: (MonadState (LinearState ann) m) => m a -> m (a, M.Map Identifier (NE.NonEmpty (Linearity ann)))
+scoped action = do
+ -- add a new empty stack
+ modify (\ls -> ls {lsUses = mempty NE.:| NE.toList (lsUses ls)})
+ -- run the action, collecting uses in NE.head of uses stack
+ result <- action
+ -- grab the top level items
+ items <- gets (NE.head . lsUses)
+ -- bin them off stack
+ modify (\ls -> ls {lsUses = NE.fromList (NE.tail (lsUses ls))})
+ -- return both things
+ pure (result, items)
+
isPrimitive :: Type ann -> Bool
isPrimitive (TPrim {}) = True
isPrimitive _ = False
@@ -124,6 +175,9 @@ getVarsInScope = gets (S.fromList . mapMaybe userDefined . M.keys . lsVars)
UserDefined i -> Just i
_ -> Nothing
+combineWithBiggestItems :: (Ord k, Foldable t) => M.Map k (t a) -> M.Map k (t a) -> M.Map k (t a)
+combineWithBiggestItems = M.unionWith (\l r -> if length r > length l then r else l)
+
decorate ::
(Show ann) =>
( MonadState (LinearState ann) m,
@@ -153,23 +207,28 @@ decorate (EMatch ty expr pats) = do
-- for vars currently in scope outside the pattern arms
existingVars <- getVarsInScope
- -- need to work out a way of scoping variables created in patterns
- -- as they only exist in `patExpr`
let decoratePair (pat, patExpr) = do
(decoratedPat, _idents) <- decoratePattern pat
- (decoratedPatExpr, patIdents) <- runWriterT (decorate patExpr)
+ ((decoratedPatExpr, uses), patIdents) <- runWriterT (scoped (decorate patExpr))
-- we only care about idents that exist in the current scope
let usefulIdents =
M.filterWithKey (\k _ -> S.member k existingVars) patIdents
- pure (usefulIdents, (decoratedPat, decoratedPatExpr))
+ pure ((usefulIdents, uses), (decoratedPat, decoratedPatExpr))
decoratedPatterns <- traverse decoratePair pats
- let allIdents = foldMap fst decoratedPatterns
+ let allIdents = foldMap (fst . fst) decoratedPatterns
+
+ let allUses = snd . fst <$> decoratedPatterns
+ combinedUses = foldr combineWithBiggestItems (NE.head allUses) (NE.tail allUses)
+
+ -- here we're gonna go through each constructor and keep the longest list of
+ -- things, then push the ones we want to keep hold of
+ pushUses combinedUses
-- now we know all the idents, we can decorate each pattern with the ones
-- it's missing
- let decorateWithIdents (idents, (pat, patExpr)) =
+ let decorateWithIdents ((idents, _), (pat, patExpr)) =
let dropIdents = DropIdentifiers <$> NE.nonEmpty (M.toList (M.difference allIdents idents))
in (pat, mapOuterExprAnnotation (second (const dropIdents)) patExpr)
@@ -179,8 +238,13 @@ decorate (EMatch ty expr pats) = do
decorate (EInfix ty op a b) =
EInfix (ty, Nothing) op <$> decorate a <*> decorate b
decorate (EIf ty predExpr thenExpr elseExpr) = do
- (decoratedThen, thenIdents) <- runWriterT (decorate thenExpr)
- (decoratedElse, elseIdents) <- runWriterT (decorate elseExpr)
+ ((decoratedThen, thenUses), thenIdents) <- runWriterT (scoped (decorate thenExpr))
+ ((decoratedElse, elseUses), elseIdents) <- runWriterT (scoped (decorate elseExpr))
+
+ -- here we're gonna go through each constructor and keep the longest list of
+ -- things, then push the ones we want to keep hold of
+ pushUses
+ (combineWithBiggestItems thenUses elseUses)
-- work out idents used in the other branch but not this one
let uniqueToThen = DropIdentifiers <$> NE.nonEmpty (M.toList (M.difference thenIdents elseIdents))
diff --git a/wasm-calc11/src/Calc/Linearity/Error.hs b/wasm-calc11/src/Calc/Linearity/Error.hs
index 4dbf5c13..46eeb9a8 100644
--- a/wasm-calc11/src/Calc/Linearity/Error.hs
+++ b/wasm-calc11/src/Calc/Linearity/Error.hs
@@ -11,6 +11,7 @@ where
import Calc.SourceSpan
import Calc.Types.Annotation
import Calc.Types.Identifier
+import qualified Data.List.NonEmpty as NE
import Data.Maybe (catMaybes, mapMaybe)
import qualified Data.Text as T
import qualified Error.Diagnose as Diag
@@ -19,7 +20,7 @@ import qualified Prettyprinter.Render.Text as PP
data LinearityError ann
= NotUsed ann Identifier
- | UsedMultipleTimes [ann] Identifier
+ | UsedMultipleTimes (NE.NonEmpty ann) Identifier
deriving stock (Eq, Ord, Show)
prettyPrint :: PP.Doc doc -> T.Text
@@ -83,7 +84,7 @@ linearityErrorDiagnostic input e =
)
)
)
- anns
+ (NE.toList anns)
)
[]
in Diag.addReport diag report
diff --git a/wasm-calc11/src/Calc/Linearity/Types.hs b/wasm-calc11/src/Calc/Linearity/Types.hs
index 6e643568..333f2376 100644
--- a/wasm-calc11/src/Calc/Linearity/Types.hs
+++ b/wasm-calc11/src/Calc/Linearity/Types.hs
@@ -39,7 +39,7 @@ data UserDefined a = UserDefined a | Internal a
data LinearState ann = LinearState
{ lsVars :: M.Map (UserDefined Identifier) (LinearityType, ann),
- lsUses :: [(Identifier, Linearity ann)],
+ lsUses :: NE.NonEmpty (M.Map Identifier (NE.NonEmpty (Linearity ann))),
lsFresh :: Natural
}
deriving stock (Eq, Ord, Show, Functor)
diff --git a/wasm-calc11/src/Calc/Linearity/Validate.hs b/wasm-calc11/src/Calc/Linearity/Validate.hs
index 9133833a..558689d0 100644
--- a/wasm-calc11/src/Calc/Linearity/Validate.hs
+++ b/wasm-calc11/src/Calc/Linearity/Validate.hs
@@ -24,6 +24,7 @@ import Control.Monad.State
import Control.Monad.Writer
import Data.Foldable (traverse_)
import Data.Functor (($>))
+import qualified Data.List.NonEmpty as NE
import qualified Data.Map as M
getLinearityAnnotation :: Linearity ann -> ann
@@ -38,50 +39,37 @@ validateGlobal ::
(Show ann) =>
Global (Type ann) ->
Either (LinearityError ann) (Expr (Type ann, Maybe (Drops ann)))
-validateGlobal glob =
+validateGlobal glob = do
let (expr, linearState) = getGlobalUses glob
- in validate linearState $> expr
+ validate linearState $> expr
validateFunction ::
(Show ann) =>
Function (Type ann) ->
Either (LinearityError ann) (Expr (Type ann, Maybe (Drops ann)))
-validateFunction fn =
+validateFunction fn = do
let (expr, linearState) = getFunctionUses fn
- in validate linearState $> expr
+ validate linearState $> expr
validate :: LinearState ann -> Either (LinearityError ann) ()
validate (LinearState {lsVars, lsUses}) =
let validateFunctionItem (Internal _, _) = Right ()
validateFunctionItem (UserDefined ident, (linearity, ann)) =
- let completeUses = filterCompleteUses lsUses ident
+ let completeUses = maybe mempty NE.toList (M.lookup ident (NE.head lsUses))
in case linearity of
LTPrimitive ->
if null completeUses
then Left (NotUsed ann ident)
else Right ()
LTBoxed ->
- case length completeUses of
- 0 -> Left (NotUsed ann ident)
- 1 -> Right ()
- _more ->
- Left (UsedMultipleTimes (getLinearityAnnotation <$> completeUses) ident)
+ case NE.nonEmpty completeUses of
+ Nothing -> Left (NotUsed ann ident)
+ Just neUses ->
+ if length neUses == 1
+ then Right ()
+ else Left (UsedMultipleTimes (getLinearityAnnotation <$> neUses) ident)
in traverse_ validateFunctionItem (M.toList lsVars)
--- | count uses of a given identifier
-filterCompleteUses ::
- [(Identifier, Linearity ann)] ->
- Identifier ->
- [Linearity ann]
-filterCompleteUses uses ident =
- foldr
- ( \(thisIdent, linearity) total -> case linearity of
- Whole _ ->
- if thisIdent == ident then linearity : total else total
- )
- []
- uses
-
getFunctionUses ::
(Show ann) =>
Function (Type ann) ->
@@ -94,7 +82,7 @@ getFunctionUses (Function {fnBody, fnArgs}) =
initialState =
LinearState
{ lsVars = initialVars,
- lsUses = mempty,
+ lsUses = NE.singleton mempty,
lsFresh = 0
}
@@ -119,6 +107,6 @@ getGlobalUses (Global {glbExpr}) =
initialState =
LinearState
{ lsVars = mempty,
- lsUses = mempty,
+ lsUses = NE.singleton mempty,
lsFresh = 0
}
diff --git a/wasm-calc11/test/Test/Linearity/LinearitySpec.hs b/wasm-calc11/test/Test/Linearity/LinearitySpec.hs
index cf859e44..b117ef09 100644
--- a/wasm-calc11/test/Test/Linearity/LinearitySpec.hs
+++ b/wasm-calc11/test/Test/Linearity/LinearitySpec.hs
@@ -241,28 +241,34 @@ spec = do
LinearState
{ lsVars =
M.fromList [(UserDefined "a", (LTPrimitive, ())), (UserDefined "b", (LTPrimitive, ()))],
- lsUses = [("b", Whole ()), ("a", Whole ())],
+ lsUses = NE.singleton (M.fromList [("b", NE.singleton $ Whole ()), ("a", NE.singleton $ Whole ())]),
lsFresh = 0
}
),
( "function pair(a: a, b: b) -> (a,b) { (a,b) }",
LinearState
{ lsVars = M.fromList [(UserDefined "a", (LTBoxed, ())), (UserDefined "b", (LTBoxed, ()))],
- lsUses = [("b", Whole ()), ("a", Whole ())],
+ lsUses =
+ NE.singleton
+ ( M.fromList
+ [ ("b", NE.singleton $ Whole ()),
+ ("a", NE.singleton $ Whole ())
+ ]
+ ),
lsFresh = 0
}
),
( "function dontUseA(a: a, b: b) -> b { b }",
LinearState
{ lsVars = M.fromList [(UserDefined "a", (LTBoxed, ())), (UserDefined "b", (LTBoxed, ()))],
- lsUses = [("b", Whole ())],
+ lsUses = NE.singleton (M.fromList [("b", NE.singleton $ Whole ())]),
lsFresh = 0
}
),
( "function dup(a: a) -> (a,a) { (a,a)}",
LinearState
{ lsVars = M.fromList [(UserDefined "a", (LTBoxed, ()))],
- lsUses = [("a", Whole ()), ("a", Whole ())],
+ lsUses = NE.singleton (M.fromList [("a", NE.fromList [Whole (), Whole ()])]),
lsFresh = 0
}
)
@@ -286,7 +292,9 @@ spec = do
"function pair(a: a, b: b) -> (a,b) { (a,b) }",
"function addPair(pair: (Int64,Int64)) -> Int64 { let (a,b) = pair; a + b }",
"function fst(pair: (a,b)) -> Box(a) { let (a,_) = pair; Box(a) }",
- "function main() -> Int64 { let _ = (1: Int64); 2 }"
+ "function main() -> Int64 { let _ = (1: Int64); 2 }",
+ "function bothSidesOfIf() -> (Boolean,Boolean) { let pair = (True,False); if True then pair else pair }",
+ "function bothSidesOfMatch() -> (Boolean,Boolean) { let pair = (True,False); case True { True -> pair, False -> pair } }"
]
traverse_
( \str -> it (T.unpack str) $ do
@@ -309,10 +317,13 @@ spec = do
NotUsed () "a"
),
( "function dup(a: a) -> (a,a) { (a,a)}",
- UsedMultipleTimes [(), ()] "a"
+ UsedMultipleTimes (NE.fromList [(), ()]) "a"
),
( "function withPair(pair: (a,b)) -> (a,a,b) { let (a,b) = pair; (a, a, b) }",
- UsedMultipleTimes [(), ()] "a"
+ UsedMultipleTimes (NE.fromList [(), ()]) "a"
+ ),
+ ( "function bothSidesOfIf() -> (Boolean,Boolean) { let pair = (True,False); if True then { let _ = pair; pair } else pair }",
+ UsedMultipleTimes (NE.fromList [(), ()]) "pair"
)
]
traverse_