From a16aa78e070329a98724ae972fb5152cdd917239 Mon Sep 17 00:00:00 2001 From: Daniel Harvey Date: Thu, 15 Aug 2024 17:52:00 +0100 Subject: [PATCH 1/4] Add a stack structure --- wasm-calc11/src/Calc/Linearity/Decorate.hs | 49 +++++++++++++++++-- wasm-calc11/src/Calc/Linearity/Types.hs | 3 +- wasm-calc11/src/Calc/Linearity/Validate.hs | 19 +++---- .../test/Test/Linearity/LinearitySpec.hs | 19 ++++--- 4 files changed, 70 insertions(+), 20 deletions(-) diff --git a/wasm-calc11/src/Calc/Linearity/Decorate.hs b/wasm-calc11/src/Calc/Linearity/Decorate.hs index 23c2a049..4e20bc5a 100644 --- a/wasm-calc11/src/Calc/Linearity/Decorate.hs +++ b/wasm-calc11/src/Calc/Linearity/Decorate.hs @@ -9,6 +9,7 @@ module Calc.Linearity.Decorate ) where +import Debug.Trace import Calc.ExprUtils import Calc.Linearity.Types import Calc.TypeUtils @@ -32,6 +33,16 @@ getFresh = do modify (\ls -> ls {lsFresh = lsFresh ls + 1}) gets lsFresh +-- | push a load of uses directly onto the head of the uses stack +pushUses :: (MonadState (LinearState ann) m) => + [(Identifier, Linearity ann)] -> m () +pushUses uses = + modify (\ls -> + let (topOfStack NE.:| restOfStack) = lsUses ls + newTopOfStack = topOfStack <> uses + in + ls {lsUses = newTopOfStack NE.:| restOfStack }) + recordUse :: ( MonadState (LinearState ann) m, MonadWriter (M.Map Identifier (Type ann)) m @@ -40,9 +51,30 @@ recordUse :: Type ann -> m () recordUse ident ty = do - modify (\ls -> ls {lsUses = (ident, Whole (getOuterTypeAnnotation ty)) : lsUses ls}) + modify (\ls -> + let (topOfStack NE.:| restOfStack) = lsUses ls + newTopOfStack = (ident, Whole (getOuterTypeAnnotation ty)) : topOfStack + in + ls {lsUses = newTopOfStack NE.:| restOfStack }) unless (isPrimitive ty) $ tell (M.singleton ident ty) -- we only want to track use of non-primitive types +-- run an action, giving it a new uses scope +-- then chop off the new values and return them +-- this allows us to dedupe and re-add them to the current stack as desired +scoped :: (MonadState (LinearState ann) m) => m a -> m (a, [(Identifier, Linearity ann)]) +scoped action = do + -- add a new empty stack + modify (\ls -> ls { lsUses = mempty NE.:| (NE.toList $ lsUses ls) }) + -- run the action, collecting uses in NE.head of uses stack + result <- action + -- grab the top level items + items <- gets (NE.head . lsUses) + -- bin them off stack + modify (\ls -> ls { lsUses = NE.fromList (NE.tail (lsUses ls)) }) + -- return both things + pure (result, items) + + isPrimitive :: Type ann -> Bool isPrimitive (TPrim {}) = True isPrimitive _ = False @@ -179,8 +211,19 @@ decorate (EMatch ty expr pats) = do decorate (EInfix ty op a b) = EInfix (ty, Nothing) op <$> decorate a <*> decorate b decorate (EIf ty predExpr thenExpr elseExpr) = do - (decoratedThen, thenIdents) <- runWriterT (decorate thenExpr) - (decoratedElse, elseIdents) <- runWriterT (decorate elseExpr) + ((decoratedThen,thenUses), thenIdents) <- runWriterT (scoped (decorate thenExpr)) + ((decoratedElse, elseUses), elseIdents) <- runWriterT (scoped (decorate elseExpr)) + + traceShowM ("thenUses" :: String, thenUses) + traceShowM ("elseUses" :: String, elseUses) + + -- here we're gonna bin off duplicates and stuff + let usesToKeep = thenUses <> elseUses + + traceShowM ("usesToKeep" :: String, usesToKeep) + + -- push the ones we want to keep hold of + pushUses usesToKeep -- work out idents used in the other branch but not this one let uniqueToThen = DropIdentifiers <$> NE.nonEmpty (M.toList (M.difference thenIdents elseIdents)) diff --git a/wasm-calc11/src/Calc/Linearity/Types.hs b/wasm-calc11/src/Calc/Linearity/Types.hs index 6e643568..438558e6 100644 --- a/wasm-calc11/src/Calc/Linearity/Types.hs +++ b/wasm-calc11/src/Calc/Linearity/Types.hs @@ -39,7 +39,8 @@ data UserDefined a = UserDefined a | Internal a data LinearState ann = LinearState { lsVars :: M.Map (UserDefined Identifier) (LinearityType, ann), - lsUses :: [(Identifier, Linearity ann)], + lsUses :: NE.NonEmpty [(Identifier, Linearity ann)], lsFresh :: Natural } deriving stock (Eq, Ord, Show, Functor) + diff --git a/wasm-calc11/src/Calc/Linearity/Validate.hs b/wasm-calc11/src/Calc/Linearity/Validate.hs index 9133833a..5b9b00d6 100644 --- a/wasm-calc11/src/Calc/Linearity/Validate.hs +++ b/wasm-calc11/src/Calc/Linearity/Validate.hs @@ -9,6 +9,7 @@ module Calc.Linearity.Validate ) where +import qualified Data.List.NonEmpty as NE import Calc.Linearity.Decorate import Calc.Linearity.Error import Calc.Linearity.Types @@ -38,23 +39,23 @@ validateGlobal :: (Show ann) => Global (Type ann) -> Either (LinearityError ann) (Expr (Type ann, Maybe (Drops ann))) -validateGlobal glob = - let (expr, linearState) = getGlobalUses glob - in validate linearState $> expr +validateGlobal glob = do + let (expr, linearState) = getGlobalUses glob + validate linearState $> expr validateFunction :: (Show ann) => Function (Type ann) -> Either (LinearityError ann) (Expr (Type ann, Maybe (Drops ann))) -validateFunction fn = +validateFunction fn = do let (expr, linearState) = getFunctionUses fn - in validate linearState $> expr + validate linearState $> expr validate :: LinearState ann -> Either (LinearityError ann) () validate (LinearState {lsVars, lsUses}) = let validateFunctionItem (Internal _, _) = Right () validateFunctionItem (UserDefined ident, (linearity, ann)) = - let completeUses = filterCompleteUses lsUses ident + let completeUses = filterCompleteUses (NE.head lsUses) ident in case linearity of LTPrimitive -> if null completeUses @@ -94,7 +95,7 @@ getFunctionUses (Function {fnBody, fnArgs}) = initialState = LinearState { lsVars = initialVars, - lsUses = mempty, + lsUses = NE.singleton mempty, lsFresh = 0 } @@ -110,7 +111,7 @@ getFunctionUses (Function {fnBody, fnArgs}) = getGlobalUses :: (Show ann) => Global (Type ann) -> - (Expr (Type ann, Maybe (Drops ann)), LinearState ann) + (Expr (Type ann, Maybe (Drops ann)), LinearState ann) getGlobalUses (Global {glbExpr}) = fst $ runIdentity $ runWriterT $ runStateT action initialState where @@ -119,6 +120,6 @@ getGlobalUses (Global {glbExpr}) = initialState = LinearState { lsVars = mempty, - lsUses = mempty, + lsUses = NE.singleton mempty, lsFresh = 0 } diff --git a/wasm-calc11/test/Test/Linearity/LinearitySpec.hs b/wasm-calc11/test/Test/Linearity/LinearitySpec.hs index cf859e44..c3a8ce88 100644 --- a/wasm-calc11/test/Test/Linearity/LinearitySpec.hs +++ b/wasm-calc11/test/Test/Linearity/LinearitySpec.hs @@ -241,28 +241,28 @@ spec = do LinearState { lsVars = M.fromList [(UserDefined "a", (LTPrimitive, ())), (UserDefined "b", (LTPrimitive, ()))], - lsUses = [("b", Whole ()), ("a", Whole ())], + lsUses = NE.singleton [("b", Whole ()), ("a", Whole ())], lsFresh = 0 } ), ( "function pair(a: a, b: b) -> (a,b) { (a,b) }", LinearState { lsVars = M.fromList [(UserDefined "a", (LTBoxed, ())), (UserDefined "b", (LTBoxed, ()))], - lsUses = [("b", Whole ()), ("a", Whole ())], + lsUses = NE.singleton [("b", Whole ()), ("a", Whole ())], lsFresh = 0 } ), ( "function dontUseA(a: a, b: b) -> b { b }", LinearState { lsVars = M.fromList [(UserDefined "a", (LTBoxed, ())), (UserDefined "b", (LTBoxed, ()))], - lsUses = [("b", Whole ())], + lsUses = NE.singleton [("b", Whole ())], lsFresh = 0 } ), ( "function dup(a: a) -> (a,a) { (a,a)}", LinearState { lsVars = M.fromList [(UserDefined "a", (LTBoxed, ()))], - lsUses = [("a", Whole ()), ("a", Whole ())], + lsUses = NE.singleton [("a", Whole ()), ("a", Whole ())], lsFresh = 0 } ) @@ -280,13 +280,15 @@ spec = do strings describe "validateFunction" $ do - describe "expected successes" $ do + fdescribe "expected successes" $ do let success = [ "function sum (a: Int64, b: Int64) -> Int64 { a + b }", "function pair(a: a, b: b) -> (a,b) { (a,b) }", "function addPair(pair: (Int64,Int64)) -> Int64 { let (a,b) = pair; a + b }", "function fst(pair: (a,b)) -> Box(a) { let (a,_) = pair; Box(a) }", - "function main() -> Int64 { let _ = (1: Int64); 2 }" + "function main() -> Int64 { let _ = (1: Int64); 2 }", + "function bothSidesOfIf() -> (Boolean,Boolean) { let pair = (True,False); if True then pair else pair }", + "function bothSidesOfMatch() -> (Boolean,Boolean) { let pair = (True,False); case True { True -> pair, False -> pair } }" ] traverse_ ( \str -> it (T.unpack str) $ do @@ -300,7 +302,7 @@ spec = do ) success - describe "expected failures" $ do + fdescribe "expected failures" $ do let failures = [ ( "function dontUseA(a: a, b: b) -> b { b }", NotUsed () "a" @@ -313,6 +315,9 @@ spec = do ), ( "function withPair(pair: (a,b)) -> (a,a,b) { let (a,b) = pair; (a, a, b) }", UsedMultipleTimes [(), ()] "a" + ), + ( "function bothSidesOfIf() -> (Boolean,Boolean) { let pair = (True,False); if True then { let _ = pair; pair } else pair }", + UsedMultipleTimes [(), ()] "pair" ) ] traverse_ From 3eab14da32ae3ce7b28a38002bef098b975f4d34 Mon Sep 17 00:00:00 2001 From: Daniel Harvey Date: Fri, 16 Aug 2024 17:51:54 +0100 Subject: [PATCH 2/4] Make the data structures better for the problem --- wasm-calc11/src/Calc/Linearity/Decorate.hs | 37 +++++++++++++------ wasm-calc11/src/Calc/Linearity/Error.hs | 5 ++- wasm-calc11/src/Calc/Linearity/Types.hs | 2 +- wasm-calc11/src/Calc/Linearity/Validate.hs | 28 ++++---------- .../test/Test/Linearity/LinearitySpec.hs | 19 +++++----- 5 files changed, 48 insertions(+), 43 deletions(-) diff --git a/wasm-calc11/src/Calc/Linearity/Decorate.hs b/wasm-calc11/src/Calc/Linearity/Decorate.hs index 4e20bc5a..f8811100 100644 --- a/wasm-calc11/src/Calc/Linearity/Decorate.hs +++ b/wasm-calc11/src/Calc/Linearity/Decorate.hs @@ -9,6 +9,7 @@ module Calc.Linearity.Decorate ) where +import Data.Foldable (traverse_) import Debug.Trace import Calc.ExprUtils import Calc.Linearity.Types @@ -35,13 +36,29 @@ getFresh = do -- | push a load of uses directly onto the head of the uses stack pushUses :: (MonadState (LinearState ann) m) => - [(Identifier, Linearity ann)] -> m () + M.Map Identifier (NE.NonEmpty (Linearity ann)) -> m () pushUses uses = + let pushForIdent ident items + = traverse_ (\(Whole ann) -> recordUsesInState ident ann) items + in traverse_ (uncurry pushForIdent) (M.toList uses) + +mapHead :: (a -> a) -> NE.NonEmpty a -> NE.NonEmpty a +mapHead f (neHead NE.:| neTail) = + (f neHead) NE.:| neTail + +recordUsesInState :: (MonadState (LinearState ann) m) => + Identifier -> ann -> m() +recordUsesInState ident ann = modify (\ls -> - let (topOfStack NE.:| restOfStack) = lsUses ls - newTopOfStack = topOfStack <> uses + let f = + M.alter (\existing -> + let newItem = Whole ann + in Just $ case existing of + Just neExisting -> newItem NE.:| (NE.toList neExisting) + Nothing -> NE.singleton newItem) ident in - ls {lsUses = newTopOfStack NE.:| restOfStack }) + ls {lsUses = mapHead f (lsUses ls) }) + recordUse :: ( MonadState (LinearState ann) m, @@ -51,17 +68,13 @@ recordUse :: Type ann -> m () recordUse ident ty = do - modify (\ls -> - let (topOfStack NE.:| restOfStack) = lsUses ls - newTopOfStack = (ident, Whole (getOuterTypeAnnotation ty)) : topOfStack - in - ls {lsUses = newTopOfStack NE.:| restOfStack }) + recordUsesInState ident (getOuterTypeAnnotation ty) unless (isPrimitive ty) $ tell (M.singleton ident ty) -- we only want to track use of non-primitive types -- run an action, giving it a new uses scope -- then chop off the new values and return them -- this allows us to dedupe and re-add them to the current stack as desired -scoped :: (MonadState (LinearState ann) m) => m a -> m (a, [(Identifier, Linearity ann)]) +scoped :: (MonadState (LinearState ann) m) => m a -> m (a, (M.Map Identifier (NE.NonEmpty (Linearity ann)))) scoped action = do -- add a new empty stack modify (\ls -> ls { lsUses = mempty NE.:| (NE.toList $ lsUses ls) }) @@ -217,8 +230,10 @@ decorate (EIf ty predExpr thenExpr elseExpr) = do traceShowM ("thenUses" :: String, thenUses) traceShowM ("elseUses" :: String, elseUses) - -- here we're gonna bin off duplicates and stuff + -- here we're gonna go through each constructor and keep the longest list of + -- things let usesToKeep = thenUses <> elseUses + _ <- error "here is the problem" traceShowM ("usesToKeep" :: String, usesToKeep) diff --git a/wasm-calc11/src/Calc/Linearity/Error.hs b/wasm-calc11/src/Calc/Linearity/Error.hs index 4dbf5c13..7e8f8f0e 100644 --- a/wasm-calc11/src/Calc/Linearity/Error.hs +++ b/wasm-calc11/src/Calc/Linearity/Error.hs @@ -16,10 +16,11 @@ import qualified Data.Text as T import qualified Error.Diagnose as Diag import qualified Prettyprinter as PP import qualified Prettyprinter.Render.Text as PP +import qualified Data.List.NonEmpty as NE data LinearityError ann = NotUsed ann Identifier - | UsedMultipleTimes [ann] Identifier + | UsedMultipleTimes (NE.NonEmpty ann) Identifier deriving stock (Eq, Ord, Show) prettyPrint :: PP.Doc doc -> T.Text @@ -83,7 +84,7 @@ linearityErrorDiagnostic input e = ) ) ) - anns + (NE.toList anns) ) [] in Diag.addReport diag report diff --git a/wasm-calc11/src/Calc/Linearity/Types.hs b/wasm-calc11/src/Calc/Linearity/Types.hs index 438558e6..4024684d 100644 --- a/wasm-calc11/src/Calc/Linearity/Types.hs +++ b/wasm-calc11/src/Calc/Linearity/Types.hs @@ -39,7 +39,7 @@ data UserDefined a = UserDefined a | Internal a data LinearState ann = LinearState { lsVars :: M.Map (UserDefined Identifier) (LinearityType, ann), - lsUses :: NE.NonEmpty [(Identifier, Linearity ann)], + lsUses :: NE.NonEmpty (M.Map Identifier (NE.NonEmpty (Linearity ann))), lsFresh :: Natural } deriving stock (Eq, Ord, Show, Functor) diff --git a/wasm-calc11/src/Calc/Linearity/Validate.hs b/wasm-calc11/src/Calc/Linearity/Validate.hs index 5b9b00d6..d9de22cd 100644 --- a/wasm-calc11/src/Calc/Linearity/Validate.hs +++ b/wasm-calc11/src/Calc/Linearity/Validate.hs @@ -55,34 +55,22 @@ validate :: LinearState ann -> Either (LinearityError ann) () validate (LinearState {lsVars, lsUses}) = let validateFunctionItem (Internal _, _) = Right () validateFunctionItem (UserDefined ident, (linearity, ann)) = - let completeUses = filterCompleteUses (NE.head lsUses) ident + let completeUses = maybe mempty NE.toList (M.lookup ident (NE.head lsUses)) in case linearity of LTPrimitive -> if null completeUses then Left (NotUsed ann ident) else Right () LTBoxed -> - case length completeUses of - 0 -> Left (NotUsed ann ident) - 1 -> Right () - _more -> - Left (UsedMultipleTimes (getLinearityAnnotation <$> completeUses) ident) + case NE.nonEmpty completeUses of + Nothing -> Left (NotUsed ann ident) + Just neUses -> + if length neUses == 1 + then Right () + else + Left (UsedMultipleTimes (getLinearityAnnotation <$> neUses) ident) in traverse_ validateFunctionItem (M.toList lsVars) --- | count uses of a given identifier -filterCompleteUses :: - [(Identifier, Linearity ann)] -> - Identifier -> - [Linearity ann] -filterCompleteUses uses ident = - foldr - ( \(thisIdent, linearity) total -> case linearity of - Whole _ -> - if thisIdent == ident then linearity : total else total - ) - [] - uses - getFunctionUses :: (Show ann) => Function (Type ann) -> diff --git a/wasm-calc11/test/Test/Linearity/LinearitySpec.hs b/wasm-calc11/test/Test/Linearity/LinearitySpec.hs index c3a8ce88..51425438 100644 --- a/wasm-calc11/test/Test/Linearity/LinearitySpec.hs +++ b/wasm-calc11/test/Test/Linearity/LinearitySpec.hs @@ -241,28 +241,29 @@ spec = do LinearState { lsVars = M.fromList [(UserDefined "a", (LTPrimitive, ())), (UserDefined "b", (LTPrimitive, ()))], - lsUses = NE.singleton [("b", Whole ()), ("a", Whole ())], + lsUses = NE.singleton (M.fromList [("b", NE.singleton $ Whole ()), ("a", NE.singleton $ Whole ())]), lsFresh = 0 } ), ( "function pair(a: a, b: b) -> (a,b) { (a,b) }", LinearState { lsVars = M.fromList [(UserDefined "a", (LTBoxed, ())), (UserDefined "b", (LTBoxed, ()))], - lsUses = NE.singleton [("b", Whole ()), ("a", Whole ())], + lsUses = NE.singleton (M.fromList [("b", NE.singleton $ Whole ()), + ("a", NE.singleton $ Whole ())]), lsFresh = 0 } ), ( "function dontUseA(a: a, b: b) -> b { b }", LinearState { lsVars = M.fromList [(UserDefined "a", (LTBoxed, ())), (UserDefined "b", (LTBoxed, ()))], - lsUses = NE.singleton [("b", Whole ())], + lsUses = NE.singleton (M.fromList [("b", NE.singleton $ Whole ())]), lsFresh = 0 } ), ( "function dup(a: a) -> (a,a) { (a,a)}", LinearState { lsVars = M.fromList [(UserDefined "a", (LTBoxed, ()))], - lsUses = NE.singleton [("a", Whole ()), ("a", Whole ())], + lsUses = NE.singleton (M.fromList [("a", NE.fromList [ Whole (),Whole () ])]), lsFresh = 0 } ) @@ -280,7 +281,7 @@ spec = do strings describe "validateFunction" $ do - fdescribe "expected successes" $ do + describe "expected successes" $ do let success = [ "function sum (a: Int64, b: Int64) -> Int64 { a + b }", "function pair(a: a, b: b) -> (a,b) { (a,b) }", @@ -302,7 +303,7 @@ spec = do ) success - fdescribe "expected failures" $ do + describe "expected failures" $ do let failures = [ ( "function dontUseA(a: a, b: b) -> b { b }", NotUsed () "a" @@ -311,13 +312,13 @@ spec = do NotUsed () "a" ), ( "function dup(a: a) -> (a,a) { (a,a)}", - UsedMultipleTimes [(), ()] "a" + UsedMultipleTimes (NE.fromList [(), ()]) "a" ), ( "function withPair(pair: (a,b)) -> (a,a,b) { let (a,b) = pair; (a, a, b) }", - UsedMultipleTimes [(), ()] "a" + UsedMultipleTimes (NE.fromList [(), ()]) "a" ), ( "function bothSidesOfIf() -> (Boolean,Boolean) { let pair = (True,False); if True then { let _ = pair; pair } else pair }", - UsedMultipleTimes [(), ()] "pair" + UsedMultipleTimes (NE.fromList [(), ()]) "pair" ) ] traverse_ From 1eccf3e29a2062b1934a53782989d502e51788c8 Mon Sep 17 00:00:00 2001 From: Daniel Harvey Date: Mon, 19 Aug 2024 15:50:30 +0100 Subject: [PATCH 3/4] Well, shit --- wasm-calc11/src/Calc/Linearity/Decorate.hs | 38 +++++++++++----------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/wasm-calc11/src/Calc/Linearity/Decorate.hs b/wasm-calc11/src/Calc/Linearity/Decorate.hs index f8811100..35375d25 100644 --- a/wasm-calc11/src/Calc/Linearity/Decorate.hs +++ b/wasm-calc11/src/Calc/Linearity/Decorate.hs @@ -10,7 +10,6 @@ module Calc.Linearity.Decorate where import Data.Foldable (traverse_) -import Debug.Trace import Calc.ExprUtils import Calc.Linearity.Types import Calc.TypeUtils @@ -18,7 +17,7 @@ import Calc.Types.Expr import Calc.Types.Identifier import Calc.Types.Pattern import Calc.Types.Type -import Control.Monad (unless) +import Control.Monad (unless ) import Control.Monad.State import Control.Monad.Writer import Data.Bifunctor (second) @@ -169,6 +168,9 @@ getVarsInScope = gets (S.fromList . mapMaybe userDefined . M.keys . lsVars) UserDefined i -> Just i _ -> Nothing +combineWithBiggestItems :: (Ord k, Foldable t) => M.Map k (t a) -> M.Map k (t a) -> M.Map k (t a) +combineWithBiggestItems = M.unionWith (\l r -> if length r > length l then r else l) + decorate :: (Show ann) => ( MonadState (LinearState ann) m, @@ -198,23 +200,28 @@ decorate (EMatch ty expr pats) = do -- for vars currently in scope outside the pattern arms existingVars <- getVarsInScope - -- need to work out a way of scoping variables created in patterns - -- as they only exist in `patExpr` let decoratePair (pat, patExpr) = do (decoratedPat, _idents) <- decoratePattern pat - (decoratedPatExpr, patIdents) <- runWriterT (decorate patExpr) + ((decoratedPatExpr, uses), patIdents) <- runWriterT (scoped (decorate patExpr)) -- we only care about idents that exist in the current scope let usefulIdents = M.filterWithKey (\k _ -> S.member k existingVars) patIdents - pure (usefulIdents, (decoratedPat, decoratedPatExpr)) + pure ((usefulIdents, uses), (decoratedPat, decoratedPatExpr)) decoratedPatterns <- traverse decoratePair pats - let allIdents = foldMap fst decoratedPatterns + let allIdents = foldMap (fst . fst) decoratedPatterns + + let allUses = snd . fst <$> decoratedPatterns + combinedUses = foldr combineWithBiggestItems (NE.head allUses) (NE.tail allUses) + + -- here we're gonna go through each constructor and keep the longest list of + -- things, then push the ones we want to keep hold of + pushUses combinedUses -- now we know all the idents, we can decorate each pattern with the ones -- it's missing - let decorateWithIdents (idents, (pat, patExpr)) = + let decorateWithIdents ((idents,_), (pat, patExpr)) = let dropIdents = DropIdentifiers <$> NE.nonEmpty (M.toList (M.difference allIdents idents)) in (pat, mapOuterExprAnnotation (second (const dropIdents)) patExpr) @@ -224,21 +231,14 @@ decorate (EMatch ty expr pats) = do decorate (EInfix ty op a b) = EInfix (ty, Nothing) op <$> decorate a <*> decorate b decorate (EIf ty predExpr thenExpr elseExpr) = do - ((decoratedThen,thenUses), thenIdents) <- runWriterT (scoped (decorate thenExpr)) + ((decoratedThen, thenUses), thenIdents) <- runWriterT (scoped (decorate thenExpr)) ((decoratedElse, elseUses), elseIdents) <- runWriterT (scoped (decorate elseExpr)) - traceShowM ("thenUses" :: String, thenUses) - traceShowM ("elseUses" :: String, elseUses) - -- here we're gonna go through each constructor and keep the longest list of - -- things - let usesToKeep = thenUses <> elseUses - _ <- error "here is the problem" - - traceShowM ("usesToKeep" :: String, usesToKeep) + -- things, then push the ones we want to keep hold of + pushUses + ( combineWithBiggestItems thenUses elseUses) - -- push the ones we want to keep hold of - pushUses usesToKeep -- work out idents used in the other branch but not this one let uniqueToThen = DropIdentifiers <$> NE.nonEmpty (M.toList (M.difference thenIdents elseIdents)) From 9ecb086bb29f0c320591f953ea6244d6dffec460 Mon Sep 17 00:00:00 2001 From: Daniel Harvey Date: Mon, 19 Aug 2024 15:52:23 +0100 Subject: [PATCH 4/4] Well --- wasm-calc11/src/Calc/Linearity/Decorate.hs | 60 ++++++++++--------- wasm-calc11/src/Calc/Linearity/Error.hs | 2 +- wasm-calc11/src/Calc/Linearity/Types.hs | 1 - wasm-calc11/src/Calc/Linearity/Validate.hs | 11 ++-- .../test/Test/Linearity/LinearitySpec.hs | 11 +++- 5 files changed, 47 insertions(+), 38 deletions(-) diff --git a/wasm-calc11/src/Calc/Linearity/Decorate.hs b/wasm-calc11/src/Calc/Linearity/Decorate.hs index 35375d25..bfd3ce5b 100644 --- a/wasm-calc11/src/Calc/Linearity/Decorate.hs +++ b/wasm-calc11/src/Calc/Linearity/Decorate.hs @@ -9,7 +9,6 @@ module Calc.Linearity.Decorate ) where -import Data.Foldable (traverse_) import Calc.ExprUtils import Calc.Linearity.Types import Calc.TypeUtils @@ -17,10 +16,11 @@ import Calc.Types.Expr import Calc.Types.Identifier import Calc.Types.Pattern import Calc.Types.Type -import Control.Monad (unless ) +import Control.Monad (unless) import Control.Monad.State import Control.Monad.Writer import Data.Bifunctor (second) +import Data.Foldable (traverse_) import qualified Data.List.NonEmpty as NE import qualified Data.Map as M import Data.Maybe (mapMaybe) @@ -34,30 +34,38 @@ getFresh = do gets lsFresh -- | push a load of uses directly onto the head of the uses stack -pushUses :: (MonadState (LinearState ann) m) => - M.Map Identifier (NE.NonEmpty (Linearity ann)) -> m () +pushUses :: + (MonadState (LinearState ann) m) => + M.Map Identifier (NE.NonEmpty (Linearity ann)) -> + m () pushUses uses = - let pushForIdent ident items - = traverse_ (\(Whole ann) -> recordUsesInState ident ann) items - in traverse_ (uncurry pushForIdent) (M.toList uses) + let pushForIdent ident = + traverse_ (\(Whole ann) -> recordUsesInState ident ann) + in traverse_ (uncurry pushForIdent) (M.toList uses) mapHead :: (a -> a) -> NE.NonEmpty a -> NE.NonEmpty a mapHead f (neHead NE.:| neTail) = - (f neHead) NE.:| neTail + f neHead NE.:| neTail -recordUsesInState :: (MonadState (LinearState ann) m) => - Identifier -> ann -> m() +recordUsesInState :: + (MonadState (LinearState ann) m) => + Identifier -> + ann -> + m () recordUsesInState ident ann = - modify (\ls -> - let f = - M.alter (\existing -> - let newItem = Whole ann - in Just $ case existing of - Just neExisting -> newItem NE.:| (NE.toList neExisting) - Nothing -> NE.singleton newItem) ident - in - ls {lsUses = mapHead f (lsUses ls) }) - + modify + ( \ls -> + let f = + M.alter + ( \existing -> + let newItem = Whole ann + in Just $ case existing of + Just neExisting -> newItem NE.:| NE.toList neExisting + Nothing -> NE.singleton newItem + ) + ident + in ls {lsUses = mapHead f (lsUses ls)} + ) recordUse :: ( MonadState (LinearState ann) m, @@ -73,20 +81,19 @@ recordUse ident ty = do -- run an action, giving it a new uses scope -- then chop off the new values and return them -- this allows us to dedupe and re-add them to the current stack as desired -scoped :: (MonadState (LinearState ann) m) => m a -> m (a, (M.Map Identifier (NE.NonEmpty (Linearity ann)))) +scoped :: (MonadState (LinearState ann) m) => m a -> m (a, M.Map Identifier (NE.NonEmpty (Linearity ann))) scoped action = do -- add a new empty stack - modify (\ls -> ls { lsUses = mempty NE.:| (NE.toList $ lsUses ls) }) + modify (\ls -> ls {lsUses = mempty NE.:| NE.toList (lsUses ls)}) -- run the action, collecting uses in NE.head of uses stack result <- action -- grab the top level items items <- gets (NE.head . lsUses) -- bin them off stack - modify (\ls -> ls { lsUses = NE.fromList (NE.tail (lsUses ls)) }) + modify (\ls -> ls {lsUses = NE.fromList (NE.tail (lsUses ls))}) -- return both things pure (result, items) - isPrimitive :: Type ann -> Bool isPrimitive (TPrim {}) = True isPrimitive _ = False @@ -221,7 +228,7 @@ decorate (EMatch ty expr pats) = do -- now we know all the idents, we can decorate each pattern with the ones -- it's missing - let decorateWithIdents ((idents,_), (pat, patExpr)) = + let decorateWithIdents ((idents, _), (pat, patExpr)) = let dropIdents = DropIdentifiers <$> NE.nonEmpty (M.toList (M.difference allIdents idents)) in (pat, mapOuterExprAnnotation (second (const dropIdents)) patExpr) @@ -237,8 +244,7 @@ decorate (EIf ty predExpr thenExpr elseExpr) = do -- here we're gonna go through each constructor and keep the longest list of -- things, then push the ones we want to keep hold of pushUses - ( combineWithBiggestItems thenUses elseUses) - + (combineWithBiggestItems thenUses elseUses) -- work out idents used in the other branch but not this one let uniqueToThen = DropIdentifiers <$> NE.nonEmpty (M.toList (M.difference thenIdents elseIdents)) diff --git a/wasm-calc11/src/Calc/Linearity/Error.hs b/wasm-calc11/src/Calc/Linearity/Error.hs index 7e8f8f0e..46eeb9a8 100644 --- a/wasm-calc11/src/Calc/Linearity/Error.hs +++ b/wasm-calc11/src/Calc/Linearity/Error.hs @@ -11,12 +11,12 @@ where import Calc.SourceSpan import Calc.Types.Annotation import Calc.Types.Identifier +import qualified Data.List.NonEmpty as NE import Data.Maybe (catMaybes, mapMaybe) import qualified Data.Text as T import qualified Error.Diagnose as Diag import qualified Prettyprinter as PP import qualified Prettyprinter.Render.Text as PP -import qualified Data.List.NonEmpty as NE data LinearityError ann = NotUsed ann Identifier diff --git a/wasm-calc11/src/Calc/Linearity/Types.hs b/wasm-calc11/src/Calc/Linearity/Types.hs index 4024684d..333f2376 100644 --- a/wasm-calc11/src/Calc/Linearity/Types.hs +++ b/wasm-calc11/src/Calc/Linearity/Types.hs @@ -43,4 +43,3 @@ data LinearState ann = LinearState lsFresh :: Natural } deriving stock (Eq, Ord, Show, Functor) - diff --git a/wasm-calc11/src/Calc/Linearity/Validate.hs b/wasm-calc11/src/Calc/Linearity/Validate.hs index d9de22cd..558689d0 100644 --- a/wasm-calc11/src/Calc/Linearity/Validate.hs +++ b/wasm-calc11/src/Calc/Linearity/Validate.hs @@ -9,7 +9,6 @@ module Calc.Linearity.Validate ) where -import qualified Data.List.NonEmpty as NE import Calc.Linearity.Decorate import Calc.Linearity.Error import Calc.Linearity.Types @@ -25,6 +24,7 @@ import Control.Monad.State import Control.Monad.Writer import Data.Foldable (traverse_) import Data.Functor (($>)) +import qualified Data.List.NonEmpty as NE import qualified Data.Map as M getLinearityAnnotation :: Linearity ann -> ann @@ -40,7 +40,7 @@ validateGlobal :: Global (Type ann) -> Either (LinearityError ann) (Expr (Type ann, Maybe (Drops ann))) validateGlobal glob = do - let (expr, linearState) = getGlobalUses glob + let (expr, linearState) = getGlobalUses glob validate linearState $> expr validateFunction :: @@ -66,9 +66,8 @@ validate (LinearState {lsVars, lsUses}) = Nothing -> Left (NotUsed ann ident) Just neUses -> if length neUses == 1 - then Right () - else - Left (UsedMultipleTimes (getLinearityAnnotation <$> neUses) ident) + then Right () + else Left (UsedMultipleTimes (getLinearityAnnotation <$> neUses) ident) in traverse_ validateFunctionItem (M.toList lsVars) getFunctionUses :: @@ -99,7 +98,7 @@ getFunctionUses (Function {fnBody, fnArgs}) = getGlobalUses :: (Show ann) => Global (Type ann) -> - (Expr (Type ann, Maybe (Drops ann)), LinearState ann) + (Expr (Type ann, Maybe (Drops ann)), LinearState ann) getGlobalUses (Global {glbExpr}) = fst $ runIdentity $ runWriterT $ runStateT action initialState where diff --git a/wasm-calc11/test/Test/Linearity/LinearitySpec.hs b/wasm-calc11/test/Test/Linearity/LinearitySpec.hs index 51425438..b117ef09 100644 --- a/wasm-calc11/test/Test/Linearity/LinearitySpec.hs +++ b/wasm-calc11/test/Test/Linearity/LinearitySpec.hs @@ -248,8 +248,13 @@ spec = do ( "function pair(a: a, b: b) -> (a,b) { (a,b) }", LinearState { lsVars = M.fromList [(UserDefined "a", (LTBoxed, ())), (UserDefined "b", (LTBoxed, ()))], - lsUses = NE.singleton (M.fromList [("b", NE.singleton $ Whole ()), - ("a", NE.singleton $ Whole ())]), + lsUses = + NE.singleton + ( M.fromList + [ ("b", NE.singleton $ Whole ()), + ("a", NE.singleton $ Whole ()) + ] + ), lsFresh = 0 } ), @@ -263,7 +268,7 @@ spec = do ( "function dup(a: a) -> (a,a) { (a,a)}", LinearState { lsVars = M.fromList [(UserDefined "a", (LTBoxed, ()))], - lsUses = NE.singleton (M.fromList [("a", NE.fromList [ Whole (),Whole () ])]), + lsUses = NE.singleton (M.fromList [("a", NE.fromList [Whole (), Whole ()])]), lsFresh = 0 } )