From 6b33be5540f5ca8efd8795783f99f56637390364 Mon Sep 17 00:00:00 2001 From: Daniel Harvey Date: Sun, 2 Jun 2024 22:00:58 +0100 Subject: [PATCH] Some sort of fix --- wasm-calc9/src/Calc/Linearity/Types.hs | 22 +++++++---- wasm-calc9/src/Calc/Linearity/Validate.hs | 39 +++++++++++++------ .../test/Test/Linearity/LinearitySpec.hs | 18 ++++----- 3 files changed, 50 insertions(+), 29 deletions(-) diff --git a/wasm-calc9/src/Calc/Linearity/Types.hs b/wasm-calc9/src/Calc/Linearity/Types.hs index a4a2c8f5..5ba87abd 100644 --- a/wasm-calc9/src/Calc/Linearity/Types.hs +++ b/wasm-calc9/src/Calc/Linearity/Types.hs @@ -1,18 +1,19 @@ -{-# LANGUAGE DeriveFunctor #-} -{-# LANGUAGE DerivingStrategies #-} -{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} module Calc.Linearity.Types ( Linearity (..), LinearityType (..), LinearState (..), + UserDefined (..) ) where -import Calc.Types.Identifier -import qualified Data.Map as M -import GHC.Natural +import Calc.Types.Identifier +import qualified Data.Map as M +import GHC.Natural -- | Are we using the whole type or bits of it? -- this distinction will be gone once we can destructure types instead, @@ -26,9 +27,14 @@ newtype Linearity ann data LinearityType = LTPrimitive | LTBoxed deriving stock (Eq, Ord, Show) +-- | differentiate between names provided by a user, and variables +-- created during linearity check to allow us to drop unnamed items +data UserDefined a = UserDefined a | Internal a + deriving stock (Eq,Ord,Show,Functor) + data LinearState ann = LinearState - { lsVars :: M.Map Identifier (LinearityType, ann), - lsUses :: [(Identifier, Linearity ann)], + { lsVars :: M.Map (UserDefined Identifier) (LinearityType, ann), + lsUses :: [(Identifier, Linearity ann)], lsFresh :: Natural } deriving stock (Eq, Ord, Show, Functor) diff --git a/wasm-calc9/src/Calc/Linearity/Validate.hs b/wasm-calc9/src/Calc/Linearity/Validate.hs index b976047c..d40f8396 100644 --- a/wasm-calc9/src/Calc/Linearity/Validate.hs +++ b/wasm-calc9/src/Calc/Linearity/Validate.hs @@ -73,17 +73,17 @@ validateFunction fn = validate :: LinearState ann -> Either (LinearityError ann) () validate (LinearState {lsVars, lsUses}) = - let validateFunctionItem (ident, (linearity, _ann)) = + let validateFunctionItem (Internal _,_) = Right () + validateFunctionItem (UserDefined ident, (linearity, ann)) = let completeUses = filterCompleteUses lsUses ident in case linearity of - LTPrimitive -> Right () - {- if null completeUses + LTPrimitive -> + if null completeUses then Left (NotUsed ann ident) - else Right () -} + else Right () LTBoxed -> case length completeUses of - 0 -> Right () - -- Left (NotUsed ann ident) + 0 -> Left (NotUsed ann ident) 1 -> Right () _more -> Left (UsedMultipleTimes (getLinearityAnnotation <$> completeUses) ident) @@ -122,7 +122,7 @@ getFunctionUses (Function {fnBody, fnArgs}) = initialVars = foldMap ( \(FunctionArg {faAnn, faName = ArgumentName arg, faType}) -> - M.singleton (Identifier arg) $ case faType of + M.singleton (UserDefined (Identifier arg)) $ case faType of TPrim {} -> (LTPrimitive, getOuterTypeAnnotation faAnn) _ -> (LTBoxed, getOuterTypeAnnotation faAnn) ) @@ -169,7 +169,7 @@ addLetBinding (PVar ty ident) = do ls { lsVars = M.insert - ident + (UserDefined ident) ( if isPrimitive ty then LTPrimitive else LTBoxed, getOuterTypeAnnotation ty ) @@ -179,15 +179,30 @@ addLetBinding (PVar ty ident) = do pure $ PVar (ty, Nothing) ident addLetBinding (PWildcard ty) = do i <- getFresh - let name = Identifier $ "_fresh_name" <> T.pack (show i) - addLetBinding $ PVar ty name + let ident = Identifier $ "_fresh_name" <> T.pack (show i) + modify + ( \ls -> + ls + { lsVars = + M.insert + (Internal ident) + ( if isPrimitive ty then LTPrimitive else LTBoxed, + getOuterTypeAnnotation ty + ) + (lsVars ls) + } + ) + pure $ PVar (ty, dropForType ty) ident addLetBinding (PBox ty pat) = - PBox (ty, Just DropMe) <$> addLetBinding pat + PBox (ty, dropForType ty) <$> addLetBinding pat addLetBinding (PTuple ty p ps) = do - PTuple (ty, Just DropMe) + PTuple (ty, dropForType ty ) <$> addLetBinding p <*> traverse addLetBinding ps +dropForType :: Type ann -> Maybe (Drops an) +dropForType ty = if isPrimitive ty then Nothing else Just DropMe + decorate :: (Show ann) => (MonadState (LinearState ann) m, MonadWriter (M.Map Identifier (Type ann)) m) => diff --git a/wasm-calc9/test/Test/Linearity/LinearitySpec.hs b/wasm-calc9/test/Test/Linearity/LinearitySpec.hs index 0662ef55..86e8333e 100644 --- a/wasm-calc9/test/Test/Linearity/LinearitySpec.hs +++ b/wasm-calc9/test/Test/Linearity/LinearitySpec.hs @@ -68,6 +68,9 @@ spec = do ) (dVar "a") ), + ("function allocUnused() -> Int64 { let _ = Box((1: Int32)); 22 }", + ELet Nothing (PVar (Just DropMe) "_fresh_name1") (EBox Nothing (EAnn Nothing dTyInt32 (dInt 1))) (dInt 22)), + ( "function incrementallyDropBoxesAfterUse() -> Int64 { let Box(outer) = Box(Box((100: Int64))); let Box(inner) = outer; inner }", ELet Nothing @@ -152,28 +155,28 @@ spec = do [ ( "function sum (a: Int64, b: Int64) -> Int64 { a + b }", LinearState { lsVars = - M.fromList [("a", (LTPrimitive, ())), ("b", (LTPrimitive, ()))], + M.fromList [(UserDefined "a", (LTPrimitive, ())), (UserDefined "b", (LTPrimitive, ()))], lsUses = [("b", Whole ()), ("a", Whole ())], lsFresh = 0 } ), ( "function pair(a: a, b: b) -> (a,b) { (a,b) }", LinearState - { lsVars = M.fromList [("a", (LTBoxed, ())), ("b", (LTBoxed, ()))], + { lsVars = M.fromList [(UserDefined "a", (LTBoxed, ())), (UserDefined "b", (LTBoxed, ()))], lsUses = [("b", Whole ()), ("a", Whole ())], lsFresh = 0 } ), ( "function dontUseA(a: a, b: b) -> b { b }", LinearState - { lsVars = M.fromList [("a", (LTBoxed, ())), ("b", (LTBoxed, ()))], + { lsVars = M.fromList [(UserDefined "a", (LTBoxed, ())), (UserDefined "b", (LTBoxed, ()))], lsUses = [("b", Whole ())], lsFresh = 0 } ), ( "function dup(a: a) -> (a,a) { (a,a)}", LinearState - { lsVars = M.fromList [("a", (LTBoxed, ()))], + { lsVars = M.fromList [(UserDefined "a", (LTBoxed, ()))], lsUses = [("a", Whole ()), ("a", Whole ())], lsFresh = 0 } @@ -214,18 +217,15 @@ spec = do describe "expected failures" $ do let failures = - [ {-( "function dontUseA(a: a, b: b) -> b { b }", + [ ( "function dontUseA(a: a, b: b) -> b { b }", NotUsed () "a" ), ( "function dontUsePrimA(a: Int64, b: Int64) -> Int64 { b }", NotUsed () "a" - ),-} + ), ( "function dup(a: a) -> (a,a) { (a,a)}", UsedMultipleTimes [(), ()] "a" ), - {-( "function twice(pair: (Int64, Int64)) { pair.1 + pair.2 }", - UsedMultipleTimes "pair" - ),-} ( "function withPair(pair: (a,b)) -> (a,a,b) { let (a,b) = pair; (a, a, b) }", UsedMultipleTimes [(), ()] "a" )