diff --git a/wasm-calc11/src/Calc/Linearity/Decorate.hs b/wasm-calc11/src/Calc/Linearity/Decorate.hs index 23c2a049..bfd3ce5b 100644 --- a/wasm-calc11/src/Calc/Linearity/Decorate.hs +++ b/wasm-calc11/src/Calc/Linearity/Decorate.hs @@ -20,6 +20,7 @@ import Control.Monad (unless) import Control.Monad.State import Control.Monad.Writer import Data.Bifunctor (second) +import Data.Foldable (traverse_) import qualified Data.List.NonEmpty as NE import qualified Data.Map as M import Data.Maybe (mapMaybe) @@ -32,6 +33,40 @@ getFresh = do modify (\ls -> ls {lsFresh = lsFresh ls + 1}) gets lsFresh +-- | push a load of uses directly onto the head of the uses stack +pushUses :: + (MonadState (LinearState ann) m) => + M.Map Identifier (NE.NonEmpty (Linearity ann)) -> + m () +pushUses uses = + let pushForIdent ident = + traverse_ (\(Whole ann) -> recordUsesInState ident ann) + in traverse_ (uncurry pushForIdent) (M.toList uses) + +mapHead :: (a -> a) -> NE.NonEmpty a -> NE.NonEmpty a +mapHead f (neHead NE.:| neTail) = + f neHead NE.:| neTail + +recordUsesInState :: + (MonadState (LinearState ann) m) => + Identifier -> + ann -> + m () +recordUsesInState ident ann = + modify + ( \ls -> + let f = + M.alter + ( \existing -> + let newItem = Whole ann + in Just $ case existing of + Just neExisting -> newItem NE.:| NE.toList neExisting + Nothing -> NE.singleton newItem + ) + ident + in ls {lsUses = mapHead f (lsUses ls)} + ) + recordUse :: ( MonadState (LinearState ann) m, MonadWriter (M.Map Identifier (Type ann)) m @@ -40,9 +75,25 @@ recordUse :: Type ann -> m () recordUse ident ty = do - modify (\ls -> ls {lsUses = (ident, Whole (getOuterTypeAnnotation ty)) : lsUses ls}) + recordUsesInState ident (getOuterTypeAnnotation ty) unless (isPrimitive ty) $ tell (M.singleton ident ty) -- we only want to track use of non-primitive types +-- run an action, giving it a new uses scope +-- then chop off the new values and return them +-- this allows us to dedupe and re-add them to the current stack as desired +scoped :: (MonadState (LinearState ann) m) => m a -> m (a, M.Map Identifier (NE.NonEmpty (Linearity ann))) +scoped action = do + -- add a new empty stack + modify (\ls -> ls {lsUses = mempty NE.:| NE.toList (lsUses ls)}) + -- run the action, collecting uses in NE.head of uses stack + result <- action + -- grab the top level items + items <- gets (NE.head . lsUses) + -- bin them off stack + modify (\ls -> ls {lsUses = NE.fromList (NE.tail (lsUses ls))}) + -- return both things + pure (result, items) + isPrimitive :: Type ann -> Bool isPrimitive (TPrim {}) = True isPrimitive _ = False @@ -124,6 +175,9 @@ getVarsInScope = gets (S.fromList . mapMaybe userDefined . M.keys . lsVars) UserDefined i -> Just i _ -> Nothing +combineWithBiggestItems :: (Ord k, Foldable t) => M.Map k (t a) -> M.Map k (t a) -> M.Map k (t a) +combineWithBiggestItems = M.unionWith (\l r -> if length r > length l then r else l) + decorate :: (Show ann) => ( MonadState (LinearState ann) m, @@ -153,23 +207,28 @@ decorate (EMatch ty expr pats) = do -- for vars currently in scope outside the pattern arms existingVars <- getVarsInScope - -- need to work out a way of scoping variables created in patterns - -- as they only exist in `patExpr` let decoratePair (pat, patExpr) = do (decoratedPat, _idents) <- decoratePattern pat - (decoratedPatExpr, patIdents) <- runWriterT (decorate patExpr) + ((decoratedPatExpr, uses), patIdents) <- runWriterT (scoped (decorate patExpr)) -- we only care about idents that exist in the current scope let usefulIdents = M.filterWithKey (\k _ -> S.member k existingVars) patIdents - pure (usefulIdents, (decoratedPat, decoratedPatExpr)) + pure ((usefulIdents, uses), (decoratedPat, decoratedPatExpr)) decoratedPatterns <- traverse decoratePair pats - let allIdents = foldMap fst decoratedPatterns + let allIdents = foldMap (fst . fst) decoratedPatterns + + let allUses = snd . fst <$> decoratedPatterns + combinedUses = foldr combineWithBiggestItems (NE.head allUses) (NE.tail allUses) + + -- here we're gonna go through each constructor and keep the longest list of + -- things, then push the ones we want to keep hold of + pushUses combinedUses -- now we know all the idents, we can decorate each pattern with the ones -- it's missing - let decorateWithIdents (idents, (pat, patExpr)) = + let decorateWithIdents ((idents, _), (pat, patExpr)) = let dropIdents = DropIdentifiers <$> NE.nonEmpty (M.toList (M.difference allIdents idents)) in (pat, mapOuterExprAnnotation (second (const dropIdents)) patExpr) @@ -179,8 +238,13 @@ decorate (EMatch ty expr pats) = do decorate (EInfix ty op a b) = EInfix (ty, Nothing) op <$> decorate a <*> decorate b decorate (EIf ty predExpr thenExpr elseExpr) = do - (decoratedThen, thenIdents) <- runWriterT (decorate thenExpr) - (decoratedElse, elseIdents) <- runWriterT (decorate elseExpr) + ((decoratedThen, thenUses), thenIdents) <- runWriterT (scoped (decorate thenExpr)) + ((decoratedElse, elseUses), elseIdents) <- runWriterT (scoped (decorate elseExpr)) + + -- here we're gonna go through each constructor and keep the longest list of + -- things, then push the ones we want to keep hold of + pushUses + (combineWithBiggestItems thenUses elseUses) -- work out idents used in the other branch but not this one let uniqueToThen = DropIdentifiers <$> NE.nonEmpty (M.toList (M.difference thenIdents elseIdents)) diff --git a/wasm-calc11/src/Calc/Linearity/Error.hs b/wasm-calc11/src/Calc/Linearity/Error.hs index 4dbf5c13..46eeb9a8 100644 --- a/wasm-calc11/src/Calc/Linearity/Error.hs +++ b/wasm-calc11/src/Calc/Linearity/Error.hs @@ -11,6 +11,7 @@ where import Calc.SourceSpan import Calc.Types.Annotation import Calc.Types.Identifier +import qualified Data.List.NonEmpty as NE import Data.Maybe (catMaybes, mapMaybe) import qualified Data.Text as T import qualified Error.Diagnose as Diag @@ -19,7 +20,7 @@ import qualified Prettyprinter.Render.Text as PP data LinearityError ann = NotUsed ann Identifier - | UsedMultipleTimes [ann] Identifier + | UsedMultipleTimes (NE.NonEmpty ann) Identifier deriving stock (Eq, Ord, Show) prettyPrint :: PP.Doc doc -> T.Text @@ -83,7 +84,7 @@ linearityErrorDiagnostic input e = ) ) ) - anns + (NE.toList anns) ) [] in Diag.addReport diag report diff --git a/wasm-calc11/src/Calc/Linearity/Types.hs b/wasm-calc11/src/Calc/Linearity/Types.hs index 6e643568..333f2376 100644 --- a/wasm-calc11/src/Calc/Linearity/Types.hs +++ b/wasm-calc11/src/Calc/Linearity/Types.hs @@ -39,7 +39,7 @@ data UserDefined a = UserDefined a | Internal a data LinearState ann = LinearState { lsVars :: M.Map (UserDefined Identifier) (LinearityType, ann), - lsUses :: [(Identifier, Linearity ann)], + lsUses :: NE.NonEmpty (M.Map Identifier (NE.NonEmpty (Linearity ann))), lsFresh :: Natural } deriving stock (Eq, Ord, Show, Functor) diff --git a/wasm-calc11/src/Calc/Linearity/Validate.hs b/wasm-calc11/src/Calc/Linearity/Validate.hs index 9133833a..558689d0 100644 --- a/wasm-calc11/src/Calc/Linearity/Validate.hs +++ b/wasm-calc11/src/Calc/Linearity/Validate.hs @@ -24,6 +24,7 @@ import Control.Monad.State import Control.Monad.Writer import Data.Foldable (traverse_) import Data.Functor (($>)) +import qualified Data.List.NonEmpty as NE import qualified Data.Map as M getLinearityAnnotation :: Linearity ann -> ann @@ -38,50 +39,37 @@ validateGlobal :: (Show ann) => Global (Type ann) -> Either (LinearityError ann) (Expr (Type ann, Maybe (Drops ann))) -validateGlobal glob = +validateGlobal glob = do let (expr, linearState) = getGlobalUses glob - in validate linearState $> expr + validate linearState $> expr validateFunction :: (Show ann) => Function (Type ann) -> Either (LinearityError ann) (Expr (Type ann, Maybe (Drops ann))) -validateFunction fn = +validateFunction fn = do let (expr, linearState) = getFunctionUses fn - in validate linearState $> expr + validate linearState $> expr validate :: LinearState ann -> Either (LinearityError ann) () validate (LinearState {lsVars, lsUses}) = let validateFunctionItem (Internal _, _) = Right () validateFunctionItem (UserDefined ident, (linearity, ann)) = - let completeUses = filterCompleteUses lsUses ident + let completeUses = maybe mempty NE.toList (M.lookup ident (NE.head lsUses)) in case linearity of LTPrimitive -> if null completeUses then Left (NotUsed ann ident) else Right () LTBoxed -> - case length completeUses of - 0 -> Left (NotUsed ann ident) - 1 -> Right () - _more -> - Left (UsedMultipleTimes (getLinearityAnnotation <$> completeUses) ident) + case NE.nonEmpty completeUses of + Nothing -> Left (NotUsed ann ident) + Just neUses -> + if length neUses == 1 + then Right () + else Left (UsedMultipleTimes (getLinearityAnnotation <$> neUses) ident) in traverse_ validateFunctionItem (M.toList lsVars) --- | count uses of a given identifier -filterCompleteUses :: - [(Identifier, Linearity ann)] -> - Identifier -> - [Linearity ann] -filterCompleteUses uses ident = - foldr - ( \(thisIdent, linearity) total -> case linearity of - Whole _ -> - if thisIdent == ident then linearity : total else total - ) - [] - uses - getFunctionUses :: (Show ann) => Function (Type ann) -> @@ -94,7 +82,7 @@ getFunctionUses (Function {fnBody, fnArgs}) = initialState = LinearState { lsVars = initialVars, - lsUses = mempty, + lsUses = NE.singleton mempty, lsFresh = 0 } @@ -119,6 +107,6 @@ getGlobalUses (Global {glbExpr}) = initialState = LinearState { lsVars = mempty, - lsUses = mempty, + lsUses = NE.singleton mempty, lsFresh = 0 } diff --git a/wasm-calc11/test/Test/Linearity/LinearitySpec.hs b/wasm-calc11/test/Test/Linearity/LinearitySpec.hs index cf859e44..b117ef09 100644 --- a/wasm-calc11/test/Test/Linearity/LinearitySpec.hs +++ b/wasm-calc11/test/Test/Linearity/LinearitySpec.hs @@ -241,28 +241,34 @@ spec = do LinearState { lsVars = M.fromList [(UserDefined "a", (LTPrimitive, ())), (UserDefined "b", (LTPrimitive, ()))], - lsUses = [("b", Whole ()), ("a", Whole ())], + lsUses = NE.singleton (M.fromList [("b", NE.singleton $ Whole ()), ("a", NE.singleton $ Whole ())]), lsFresh = 0 } ), ( "function pair(a: a, b: b) -> (a,b) { (a,b) }", LinearState { lsVars = M.fromList [(UserDefined "a", (LTBoxed, ())), (UserDefined "b", (LTBoxed, ()))], - lsUses = [("b", Whole ()), ("a", Whole ())], + lsUses = + NE.singleton + ( M.fromList + [ ("b", NE.singleton $ Whole ()), + ("a", NE.singleton $ Whole ()) + ] + ), lsFresh = 0 } ), ( "function dontUseA(a: a, b: b) -> b { b }", LinearState { lsVars = M.fromList [(UserDefined "a", (LTBoxed, ())), (UserDefined "b", (LTBoxed, ()))], - lsUses = [("b", Whole ())], + lsUses = NE.singleton (M.fromList [("b", NE.singleton $ Whole ())]), lsFresh = 0 } ), ( "function dup(a: a) -> (a,a) { (a,a)}", LinearState { lsVars = M.fromList [(UserDefined "a", (LTBoxed, ()))], - lsUses = [("a", Whole ()), ("a", Whole ())], + lsUses = NE.singleton (M.fromList [("a", NE.fromList [Whole (), Whole ()])]), lsFresh = 0 } ) @@ -286,7 +292,9 @@ spec = do "function pair(a: a, b: b) -> (a,b) { (a,b) }", "function addPair(pair: (Int64,Int64)) -> Int64 { let (a,b) = pair; a + b }", "function fst(pair: (a,b)) -> Box(a) { let (a,_) = pair; Box(a) }", - "function main() -> Int64 { let _ = (1: Int64); 2 }" + "function main() -> Int64 { let _ = (1: Int64); 2 }", + "function bothSidesOfIf() -> (Boolean,Boolean) { let pair = (True,False); if True then pair else pair }", + "function bothSidesOfMatch() -> (Boolean,Boolean) { let pair = (True,False); case True { True -> pair, False -> pair } }" ] traverse_ ( \str -> it (T.unpack str) $ do @@ -309,10 +317,13 @@ spec = do NotUsed () "a" ), ( "function dup(a: a) -> (a,a) { (a,a)}", - UsedMultipleTimes [(), ()] "a" + UsedMultipleTimes (NE.fromList [(), ()]) "a" ), ( "function withPair(pair: (a,b)) -> (a,a,b) { let (a,b) = pair; (a, a, b) }", - UsedMultipleTimes [(), ()] "a" + UsedMultipleTimes (NE.fromList [(), ()]) "a" + ), + ( "function bothSidesOfIf() -> (Boolean,Boolean) { let pair = (True,False); if True then { let _ = pair; pair } else pair }", + UsedMultipleTimes (NE.fromList [(), ()]) "pair" ) ] traverse_