Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Branching in linearity checking #41

Merged
merged 4 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 73 additions & 9 deletions wasm-calc11/src/Calc/Linearity/Decorate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import Control.Monad (unless)
import Control.Monad.State
import Control.Monad.Writer
import Data.Bifunctor (second)
import Data.Foldable (traverse_)
import qualified Data.List.NonEmpty as NE
import qualified Data.Map as M
import Data.Maybe (mapMaybe)
Expand All @@ -32,6 +33,40 @@ getFresh = do
modify (\ls -> ls {lsFresh = lsFresh ls + 1})
gets lsFresh

-- | push a load of uses directly onto the head of the uses stack
pushUses ::
(MonadState (LinearState ann) m) =>
M.Map Identifier (NE.NonEmpty (Linearity ann)) ->
m ()
pushUses uses =
let pushForIdent ident =
traverse_ (\(Whole ann) -> recordUsesInState ident ann)
in traverse_ (uncurry pushForIdent) (M.toList uses)

mapHead :: (a -> a) -> NE.NonEmpty a -> NE.NonEmpty a
mapHead f (neHead NE.:| neTail) =
f neHead NE.:| neTail

recordUsesInState ::
(MonadState (LinearState ann) m) =>
Identifier ->
ann ->
m ()
recordUsesInState ident ann =
modify
( \ls ->
let f =
M.alter
( \existing ->
let newItem = Whole ann
in Just $ case existing of
Just neExisting -> newItem NE.:| NE.toList neExisting
Nothing -> NE.singleton newItem
)
ident
in ls {lsUses = mapHead f (lsUses ls)}
)

recordUse ::
( MonadState (LinearState ann) m,
MonadWriter (M.Map Identifier (Type ann)) m
Expand All @@ -40,9 +75,25 @@ recordUse ::
Type ann ->
m ()
recordUse ident ty = do
modify (\ls -> ls {lsUses = (ident, Whole (getOuterTypeAnnotation ty)) : lsUses ls})
recordUsesInState ident (getOuterTypeAnnotation ty)
unless (isPrimitive ty) $ tell (M.singleton ident ty) -- we only want to track use of non-primitive types

-- run an action, giving it a new uses scope
-- then chop off the new values and return them
-- this allows us to dedupe and re-add them to the current stack as desired
scoped :: (MonadState (LinearState ann) m) => m a -> m (a, M.Map Identifier (NE.NonEmpty (Linearity ann)))
scoped action = do
-- add a new empty stack
modify (\ls -> ls {lsUses = mempty NE.:| NE.toList (lsUses ls)})
-- run the action, collecting uses in NE.head of uses stack
result <- action
-- grab the top level items
items <- gets (NE.head . lsUses)
-- bin them off stack
modify (\ls -> ls {lsUses = NE.fromList (NE.tail (lsUses ls))})
-- return both things
pure (result, items)

isPrimitive :: Type ann -> Bool
isPrimitive (TPrim {}) = True
isPrimitive _ = False
Expand Down Expand Up @@ -124,6 +175,9 @@ getVarsInScope = gets (S.fromList . mapMaybe userDefined . M.keys . lsVars)
UserDefined i -> Just i
_ -> Nothing

combineWithBiggestItems :: (Ord k, Foldable t) => M.Map k (t a) -> M.Map k (t a) -> M.Map k (t a)
combineWithBiggestItems = M.unionWith (\l r -> if length r > length l then r else l)

decorate ::
(Show ann) =>
( MonadState (LinearState ann) m,
Expand Down Expand Up @@ -153,23 +207,28 @@ decorate (EMatch ty expr pats) = do
-- for vars currently in scope outside the pattern arms
existingVars <- getVarsInScope

-- need to work out a way of scoping variables created in patterns
-- as they only exist in `patExpr`
let decoratePair (pat, patExpr) = do
(decoratedPat, _idents) <- decoratePattern pat
(decoratedPatExpr, patIdents) <- runWriterT (decorate patExpr)
((decoratedPatExpr, uses), patIdents) <- runWriterT (scoped (decorate patExpr))
-- we only care about idents that exist in the current scope
let usefulIdents =
M.filterWithKey (\k _ -> S.member k existingVars) patIdents
pure (usefulIdents, (decoratedPat, decoratedPatExpr))
pure ((usefulIdents, uses), (decoratedPat, decoratedPatExpr))

decoratedPatterns <- traverse decoratePair pats

let allIdents = foldMap fst decoratedPatterns
let allIdents = foldMap (fst . fst) decoratedPatterns

let allUses = snd . fst <$> decoratedPatterns
combinedUses = foldr combineWithBiggestItems (NE.head allUses) (NE.tail allUses)

-- here we're gonna go through each constructor and keep the longest list of
-- things, then push the ones we want to keep hold of
pushUses combinedUses

-- now we know all the idents, we can decorate each pattern with the ones
-- it's missing
let decorateWithIdents (idents, (pat, patExpr)) =
let decorateWithIdents ((idents, _), (pat, patExpr)) =
let dropIdents = DropIdentifiers <$> NE.nonEmpty (M.toList (M.difference allIdents idents))
in (pat, mapOuterExprAnnotation (second (const dropIdents)) patExpr)

Expand All @@ -179,8 +238,13 @@ decorate (EMatch ty expr pats) = do
decorate (EInfix ty op a b) =
EInfix (ty, Nothing) op <$> decorate a <*> decorate b
decorate (EIf ty predExpr thenExpr elseExpr) = do
(decoratedThen, thenIdents) <- runWriterT (decorate thenExpr)
(decoratedElse, elseIdents) <- runWriterT (decorate elseExpr)
((decoratedThen, thenUses), thenIdents) <- runWriterT (scoped (decorate thenExpr))
((decoratedElse, elseUses), elseIdents) <- runWriterT (scoped (decorate elseExpr))

-- here we're gonna go through each constructor and keep the longest list of
-- things, then push the ones we want to keep hold of
pushUses
(combineWithBiggestItems thenUses elseUses)

-- work out idents used in the other branch but not this one
let uniqueToThen = DropIdentifiers <$> NE.nonEmpty (M.toList (M.difference thenIdents elseIdents))
Expand Down
5 changes: 3 additions & 2 deletions wasm-calc11/src/Calc/Linearity/Error.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ where
import Calc.SourceSpan
import Calc.Types.Annotation
import Calc.Types.Identifier
import qualified Data.List.NonEmpty as NE
import Data.Maybe (catMaybes, mapMaybe)
import qualified Data.Text as T
import qualified Error.Diagnose as Diag
Expand All @@ -19,7 +20,7 @@ import qualified Prettyprinter.Render.Text as PP

data LinearityError ann
= NotUsed ann Identifier
| UsedMultipleTimes [ann] Identifier
| UsedMultipleTimes (NE.NonEmpty ann) Identifier
deriving stock (Eq, Ord, Show)

prettyPrint :: PP.Doc doc -> T.Text
Expand Down Expand Up @@ -83,7 +84,7 @@ linearityErrorDiagnostic input e =
)
)
)
anns
(NE.toList anns)
)
[]
in Diag.addReport diag report
Expand Down
2 changes: 1 addition & 1 deletion wasm-calc11/src/Calc/Linearity/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ data UserDefined a = UserDefined a | Internal a

data LinearState ann = LinearState
{ lsVars :: M.Map (UserDefined Identifier) (LinearityType, ann),
lsUses :: [(Identifier, Linearity ann)],
lsUses :: NE.NonEmpty (M.Map Identifier (NE.NonEmpty (Linearity ann))),
lsFresh :: Natural
}
deriving stock (Eq, Ord, Show, Functor)
40 changes: 14 additions & 26 deletions wasm-calc11/src/Calc/Linearity/Validate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import Control.Monad.State
import Control.Monad.Writer
import Data.Foldable (traverse_)
import Data.Functor (($>))
import qualified Data.List.NonEmpty as NE
import qualified Data.Map as M

getLinearityAnnotation :: Linearity ann -> ann
Expand All @@ -38,50 +39,37 @@ validateGlobal ::
(Show ann) =>
Global (Type ann) ->
Either (LinearityError ann) (Expr (Type ann, Maybe (Drops ann)))
validateGlobal glob =
validateGlobal glob = do
let (expr, linearState) = getGlobalUses glob
in validate linearState $> expr
validate linearState $> expr

validateFunction ::
(Show ann) =>
Function (Type ann) ->
Either (LinearityError ann) (Expr (Type ann, Maybe (Drops ann)))
validateFunction fn =
validateFunction fn = do
let (expr, linearState) = getFunctionUses fn
in validate linearState $> expr
validate linearState $> expr

validate :: LinearState ann -> Either (LinearityError ann) ()
validate (LinearState {lsVars, lsUses}) =
let validateFunctionItem (Internal _, _) = Right ()
validateFunctionItem (UserDefined ident, (linearity, ann)) =
let completeUses = filterCompleteUses lsUses ident
let completeUses = maybe mempty NE.toList (M.lookup ident (NE.head lsUses))
in case linearity of
LTPrimitive ->
if null completeUses
then Left (NotUsed ann ident)
else Right ()
LTBoxed ->
case length completeUses of
0 -> Left (NotUsed ann ident)
1 -> Right ()
_more ->
Left (UsedMultipleTimes (getLinearityAnnotation <$> completeUses) ident)
case NE.nonEmpty completeUses of
Nothing -> Left (NotUsed ann ident)
Just neUses ->
if length neUses == 1
then Right ()
else Left (UsedMultipleTimes (getLinearityAnnotation <$> neUses) ident)
in traverse_ validateFunctionItem (M.toList lsVars)

-- | count uses of a given identifier
filterCompleteUses ::
[(Identifier, Linearity ann)] ->
Identifier ->
[Linearity ann]
filterCompleteUses uses ident =
foldr
( \(thisIdent, linearity) total -> case linearity of
Whole _ ->
if thisIdent == ident then linearity : total else total
)
[]
uses

getFunctionUses ::
(Show ann) =>
Function (Type ann) ->
Expand All @@ -94,7 +82,7 @@ getFunctionUses (Function {fnBody, fnArgs}) =
initialState =
LinearState
{ lsVars = initialVars,
lsUses = mempty,
lsUses = NE.singleton mempty,
lsFresh = 0
}

Expand All @@ -119,6 +107,6 @@ getGlobalUses (Global {glbExpr}) =
initialState =
LinearState
{ lsVars = mempty,
lsUses = mempty,
lsUses = NE.singleton mempty,
lsFresh = 0
}
25 changes: 18 additions & 7 deletions wasm-calc11/test/Test/Linearity/LinearitySpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -241,28 +241,34 @@ spec = do
LinearState
{ lsVars =
M.fromList [(UserDefined "a", (LTPrimitive, ())), (UserDefined "b", (LTPrimitive, ()))],
lsUses = [("b", Whole ()), ("a", Whole ())],
lsUses = NE.singleton (M.fromList [("b", NE.singleton $ Whole ()), ("a", NE.singleton $ Whole ())]),
lsFresh = 0
}
),
( "function pair<a,b>(a: a, b: b) -> (a,b) { (a,b) }",
LinearState
{ lsVars = M.fromList [(UserDefined "a", (LTBoxed, ())), (UserDefined "b", (LTBoxed, ()))],
lsUses = [("b", Whole ()), ("a", Whole ())],
lsUses =
NE.singleton
( M.fromList
[ ("b", NE.singleton $ Whole ()),
("a", NE.singleton $ Whole ())
]
),
lsFresh = 0
}
),
( "function dontUseA<a,b>(a: a, b: b) -> b { b }",
LinearState
{ lsVars = M.fromList [(UserDefined "a", (LTBoxed, ())), (UserDefined "b", (LTBoxed, ()))],
lsUses = [("b", Whole ())],
lsUses = NE.singleton (M.fromList [("b", NE.singleton $ Whole ())]),
lsFresh = 0
}
),
( "function dup<a>(a: a) -> (a,a) { (a,a)}",
LinearState
{ lsVars = M.fromList [(UserDefined "a", (LTBoxed, ()))],
lsUses = [("a", Whole ()), ("a", Whole ())],
lsUses = NE.singleton (M.fromList [("a", NE.fromList [Whole (), Whole ()])]),
lsFresh = 0
}
)
Expand All @@ -286,7 +292,9 @@ spec = do
"function pair<a,b>(a: a, b: b) -> (a,b) { (a,b) }",
"function addPair(pair: (Int64,Int64)) -> Int64 { let (a,b) = pair; a + b }",
"function fst<a,b>(pair: (a,b)) -> Box(a) { let (a,_) = pair; Box(a) }",
"function main() -> Int64 { let _ = (1: Int64); 2 }"
"function main() -> Int64 { let _ = (1: Int64); 2 }",
"function bothSidesOfIf() -> (Boolean,Boolean) { let pair = (True,False); if True then pair else pair }",
"function bothSidesOfMatch() -> (Boolean,Boolean) { let pair = (True,False); case True { True -> pair, False -> pair } }"
]
traverse_
( \str -> it (T.unpack str) $ do
Expand All @@ -309,10 +317,13 @@ spec = do
NotUsed () "a"
),
( "function dup<a>(a: a) -> (a,a) { (a,a)}",
UsedMultipleTimes [(), ()] "a"
UsedMultipleTimes (NE.fromList [(), ()]) "a"
),
( "function withPair<a,b>(pair: (a,b)) -> (a,a,b) { let (a,b) = pair; (a, a, b) }",
UsedMultipleTimes [(), ()] "a"
UsedMultipleTimes (NE.fromList [(), ()]) "a"
),
( "function bothSidesOfIf() -> (Boolean,Boolean) { let pair = (True,False); if True then { let _ = pair; pair } else pair }",
UsedMultipleTimes (NE.fromList [(), ()]) "pair"
)
]
traverse_
Expand Down
Loading