Skip to content

Commit

Permalink
Merge pull request #1334 from google-research/primitive-name-map-e
Browse files Browse the repository at this point in the history
Redefine NameMapE as the primitive type
  • Loading branch information
axch authored Aug 1, 2023
2 parents c7373b2 + c372cba commit 3cbde4c
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 98 deletions.
15 changes: 8 additions & 7 deletions src/lib/CheckType.hs
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,13 @@ liftTyperM cont =
affineUsed :: AtomName r o -> TyperM r i o ()
affineUsed name = TyperM $ do
affines <- get
case lookupNameMap name affines of
Just n -> if n > 0 then
throw TypeErr $ "Affine name " ++ pprint name ++ " used " ++ show (n + 1) ++ " times."
else
put $ insertNameMap name (n + 1) affines
Nothing -> put $ insertNameMap name 1 affines
case lookupNameMapE name affines of
Just (LiftE n) ->
if n > 0 then
throw TypeErr $ "Affine name " ++ pprint name ++ " used " ++ show (n + 1) ++ " times."
else
put $ insertNameMapE name (LiftE $ n + 1) affines
Nothing -> put $ insertNameMapE name (LiftE 1) affines

parallelAffines :: [TyperM r i o a] -> TyperM r i o [a]
parallelAffines actions = TyperM $ do
Expand All @@ -77,7 +78,7 @@ parallelAffines actions = TyperM $ do
result <- runTyperT' act
(result,) <$> get
put affines
forM_ (toListNameMap $ unionsWithNameMap max isolateds) \(name, ct) ->
forM_ (toListNameMapE $ unionsWithNameMapE max isolateds) \(name, (LiftE ct)) ->
case ct of
0 -> return ()
1 -> runTyperT' $ affineUsed name
Expand Down
2 changes: 0 additions & 2 deletions src/lib/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -482,5 +482,3 @@ freshNameM hint = do
Distinct <- getDistinct
return $ withFresh hint scope \b -> Abs b (binderName b)
{-# INLINE freshNameM #-}

type AtomNameMap r = NameMap (AtomNameC r)
11 changes: 6 additions & 5 deletions src/lib/Lower.hs
Original file line number Diff line number Diff line change
Expand Up @@ -217,17 +217,18 @@ lowerCase maybeDest scrut alts resultTy = do
-- so that it never allocates scratch space for its result, but will put it directly in
-- the corresponding slice of the full 2D buffer.

type DestAssignment (i'::S) (o::S) = AtomNameMap SimpIR (ProjDest o) i'
type DestAssignment (i'::S) (o::S) = NameMap (AtomNameC SimpIR) (ProjDest o) i'

data ProjDest o
= FullDest (Dest SimpIR o)
| ProjDest (NE.NonEmpty Projection) (Dest SimpIR o) -- dest corresponds to the projection applied to name
deriving (Show)

instance SinkableE ProjDest where
sinkingProofE = todoSinkableProof

lookupDest :: DestAssignment i' o -> SAtomName i' -> Maybe (ProjDest o)
lookupDest = flip lookupNameMap
lookupDest dests = fmap fromLiftE . flip lookupNameMapE dests

-- Matches up the free variables of the atom, with the given dest. For example, if the
-- atom is a pair of two variables, the dest might be split into per-component dests,
Expand All @@ -238,10 +239,10 @@ lookupDest = flip lookupNameMap
-- XXX: When adding more cases, be careful about potentially repeated vars in the output!
decomposeDest :: Emits o => Dest SimpIR o -> SAtom i' -> LowerM i o (Maybe (DestAssignment i' o))
decomposeDest dest = \case
Var v -> return $ Just $ singletonNameMap (atomVarName v) $ FullDest dest
Var v -> return $ Just $ singletonNameMapE (atomVarName v) $ LiftE $ FullDest dest
ProjectElt _ p x -> do
(ps, v) <- return $ asNaryProj p x
return $ Just $ singletonNameMap (atomVarName v) $ ProjDest ps dest
return $ Just $ singletonNameMapE (atomVarName v) $ LiftE $ ProjDest ps dest
_ -> return Nothing

lowerBlockWithDest :: Emits o => Dest SimpIR o -> SBlock i -> LowerM i o (SAtom o)
Expand All @@ -258,7 +259,7 @@ lowerBlockWithDest dest (Abs decls ans) = do
Just DistinctBetween -> do
s' <- traverseDeclNestWithDestS destMap s decls
-- But we have to emit explicit writes, for all the vars that are not defined in decls!
forM_ (toListNameMap $ hoistFilterNameMap decls destMap) \(n, d) -> do
forM_ (toListNameMapE $ hoistNameMap decls destMap) \(n, (LiftE d)) -> do
x <- case s ! n of
Rename v -> Var <$> toAtomVar v
SubstVal a -> return a
Expand Down
4 changes: 2 additions & 2 deletions src/lib/MTL1.hs
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,8 @@ instance HoistableState UnitE where
hoistState _ _ UnitE = UnitE
{-# INLINE hoistState #-}

instance HoistableState (NameMap c a) where
hoistState _ b m = hoistFilterNameMap b m
instance Show a => HoistableState (NameMap c a) where
hoistState _ b m = hoistNameMap b m
{-# INLINE hoistState #-}

-------------------- ScopedT1 --------------------
Expand Down
128 changes: 55 additions & 73 deletions src/lib/Name.hs
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ newtype NonEmptyListE (e::E) (n::S) = NonEmptyListE { fromNonEmptyListE :: NonEm
deriving (Show, Eq, Generic)

newtype LiftE (a:: *) (n::S) = LiftE { fromLiftE :: a }
deriving (Show, Eq, Generic, Monoid, Semigroup)
deriving (Show, Eq, Ord, Generic, Monoid, Semigroup)

newtype ComposeE (f :: * -> *) (e::E) (n::S) =
ComposeE { fromComposeE :: (f (e n)) }
Expand Down Expand Up @@ -3256,113 +3256,85 @@ instance HoistableB b => HoistableB (WithAttrB a b) where

-- === extra data structures ===

-- A map from names in some scope to values that do not contain names. This is
-- not trying to enforce completeness -- a name in the scope can fail to be in
-- the map.

-- Hoisting the map removes entries that are no longer in scope.

newtype NameMap (c::C) (a:: *) (n::S) = UnsafeNameMap (RawNameMap a)
deriving (Eq, Semigroup, Monoid, Store)

hoistFilterNameMap :: BindsNames b => b n l -> NameMap c a l -> NameMap c a n
hoistFilterNameMap b (UnsafeNameMap raw) =
UnsafeNameMap $ raw `R.difference` frag
where UnsafeMakeScopeFrag frag = toScopeFrag b
{-# INLINE hoistFilterNameMap #-}

insertNameMap :: Name c n -> a -> NameMap c a n -> NameMap c a n
insertNameMap (UnsafeMakeName n) x (UnsafeNameMap raw) = UnsafeNameMap $ R.insert n x raw
{-# INLINE insertNameMap #-}

lookupNameMap :: Name c n -> NameMap c a n -> Maybe a
lookupNameMap (UnsafeMakeName n) (UnsafeNameMap raw) = R.lookup n raw
{-# INLINE lookupNameMap #-}

singletonNameMap :: Name c n -> a -> NameMap c a n
singletonNameMap (UnsafeMakeName n) x = UnsafeNameMap $ R.singleton n x
{-# INLINE singletonNameMap #-}

toListNameMap :: NameMap c a n -> [(Name c n, a)]
toListNameMap (UnsafeNameMap raw) = R.toList raw <&> \(r, x) -> (UnsafeMakeName r, x)
{-# INLINE toListNameMap #-}

unionWithNameMap :: (a -> a -> a) -> NameMap c a n -> NameMap c a n -> NameMap c a n
unionWithNameMap f (UnsafeNameMap raw1) (UnsafeNameMap raw2) =
UnsafeNameMap $ R.unionWith f raw1 raw2
{-# INLINE unionWithNameMap #-}

unionsWithNameMap :: (Foldable f) => (a -> a -> a) -> f (NameMap c a n) -> NameMap c a n
unionsWithNameMap func maps =
foldl' (unionWithNameMap func) mempty maps
{-# INLINE unionsWithNameMap #-}

traverseNameMap :: (Applicative f) => (a -> f b)
-> NameMap c a n -> f (NameMap c b n)
traverseNameMap f (UnsafeNameMap raw) = UnsafeNameMap <$> traverse f raw
{-# INLINE traverseNameMap #-}

mapNameMap :: (a -> b) -> NameMap c a n -> (NameMap c b n)
mapNameMap f (UnsafeNameMap raw) = UnsafeNameMap $ fmap f raw
{-# INLINE mapNameMap #-}

keysNameMap :: NameMap c a n -> [Name c n]
keysNameMap = map fst . toListNameMap
{-# INLINE keysNameMap #-}

keySetNameMap :: (Color c) => NameMap c a n -> NameSet n
keySetNameMap nmap = freeVarsE $ ListE $ keysNameMap nmap

instance SinkableE (NameMap c a) where
sinkingProofE = undefined
-- A map from names in some scope to values that may contain names
-- from the same scope. This is not trying to enforce completeness --
-- a name in the scope can fail to be in the map.

-- Hoisting the map removes entries for names that are no longer in
-- scope, and then attempts to hoist the remaining values.

-- This structure is useful for bottom-up code traversals. Once one
-- has traversed some term in scope n, one may be carrying information
-- associated with (some of) the free variables of the term. These
-- free variables are necessarily in the scope n, though they need by
-- no means be all the names in the scope n (that's what a Subst is
-- for). But, if the traversal is alpha-invariant, it cannot be
-- carrying any information about names bound within the term, only
-- the free ones.
--
-- Further, if the information being carried is E-kinded, the names
-- therein should be resolvable in the same scope n, since those are
-- the only names that are given meaning by the context of the term
-- being traversed.

newtype NameMapE (c::C) (e:: E) (n::S) = NameMapE (NameMap c (e n) n)
newtype NameMapE (c::C) (e:: E) (n::S) = UnsafeNameMapE (RawNameMap (e n))
deriving (Eq, Semigroup, Monoid, Store)

-- Filters out the entry(ies) for the binder being hoisted above,
-- and hoists the values of the remaining entries.
hoistNameMapE :: (BindsNames b, HoistableE e, ShowE e)
=> b n l -> NameMapE c e l -> HoistExcept (NameMapE c e n)
hoistNameMapE b (NameMapE nmap) =
NameMapE <$> (traverseNameMap (hoist b) $ hoistFilterNameMap b nmap) where
hoistNameMapE b (UnsafeNameMapE raw) =
UnsafeNameMapE <$> traverse (hoist b) diff
where
diff = raw `R.difference` frag
UnsafeMakeScopeFrag frag = toScopeFrag b
{-# INLINE hoistNameMapE #-}

insertNameMapE :: Name c n -> e n -> NameMapE c e n -> NameMapE c e n
insertNameMapE n x (NameMapE nmap) = NameMapE $ insertNameMap n x nmap
insertNameMapE (UnsafeMakeName n) x (UnsafeNameMapE raw)
= UnsafeNameMapE $ R.insert n x raw
{-# INLINE insertNameMapE #-}

lookupNameMapE :: Name c n -> NameMapE c e n -> Maybe (e n)
lookupNameMapE n (NameMapE nmap) = lookupNameMap n nmap
lookupNameMapE (UnsafeMakeName n) (UnsafeNameMapE raw) = R.lookup n raw
{-# INLINE lookupNameMapE #-}

singletonNameMapE :: Name c n -> e n -> NameMapE c e n
singletonNameMapE n x = NameMapE $ singletonNameMap n x
singletonNameMapE (UnsafeMakeName n) x = UnsafeNameMapE $ R.singleton n x
{-# INLINE singletonNameMapE #-}

toListNameMapE :: NameMapE c e n -> [(Name c n, (e n))]
toListNameMapE (NameMapE nmap) = toListNameMap nmap
toListNameMapE (UnsafeNameMapE raw) =
R.toList raw <&> \(r, x) -> (UnsafeMakeName r, x)
{-# INLINE toListNameMapE #-}

unionWithNameMapE :: (e n -> e n -> e n) -> NameMapE c e n -> NameMapE c e n -> NameMapE c e n
unionWithNameMapE f (NameMapE nmap1) (NameMapE nmap2) =
NameMapE $ unionWithNameMap f nmap1 nmap2
unionWithNameMapE f (UnsafeNameMapE raw1) (UnsafeNameMapE raw2) =
UnsafeNameMapE $ R.unionWith f raw1 raw2
{-# INLINE unionWithNameMapE #-}

unionsWithNameMapE :: (Foldable f) => (e n -> e n -> e n) -> f (NameMapE c e n) -> NameMapE c e n
unionsWithNameMapE func maps =
foldl' (unionWithNameMapE func) mempty maps
{-# INLINE unionsWithNameMapE #-}

traverseNameMapE :: (Applicative f) => (e1 n -> f (e2 n))
-> NameMapE c e1 n -> f (NameMapE c e2 n)
traverseNameMapE f (NameMapE nmap) = NameMapE <$> traverseNameMap f nmap
traverseNameMapE f (UnsafeNameMapE raw) = UnsafeNameMapE <$> traverse f raw
{-# INLINE traverseNameMapE #-}

mapNameMapE :: (e1 n -> e2 n)
-> NameMapE c e1 n -> NameMapE c e2 n
mapNameMapE f (NameMapE nmap) = NameMapE $ mapNameMap f nmap
mapNameMapE f (UnsafeNameMapE raw) = UnsafeNameMapE $ fmap f raw
{-# INLINE mapNameMapE #-}

keysNameMapE :: NameMapE c e n -> [Name c n]
keysNameMapE (NameMapE nmap) = keysNameMap nmap
keysNameMapE = map fst . toListNameMapE
{-# INLINE keysNameMapE #-}

keySetNameMapE :: (Color c) => NameMapE c e n -> NameSet n
keySetNameMapE (NameMapE nmap) = keySetNameMap nmap
keySetNameMapE nmap = freeVarsE $ ListE $ keysNameMapE nmap

instance SinkableE e => SinkableE (NameMapE c e) where
sinkingProofE = undefined
Expand All @@ -3373,6 +3345,16 @@ instance RenameE e => RenameE (NameMapE c e) where
instance HoistableE e => HoistableE (NameMapE c e) where
freeVarsE = undefined

-- A small short-cut: When the information in a NameMapE does not, in
-- fact, reference any names, hoisting the entries cannot fail.

type NameMap (c::C) (a:: *) = NameMapE c (LiftE a)

hoistNameMap :: (BindsNames b, Show a)
=> b n l -> NameMap c a l -> (NameMap c a n)
hoistNameMap b = ignoreHoistFailure . hoistNameMapE b
{-# INLINE hoistNameMap #-}

-- === E-kinded IR coercions ===

-- XXX: the intention is that we won't have to use this much
Expand Down
8 changes: 3 additions & 5 deletions src/lib/Occurrence.hs
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,10 @@ class MaxPlus a where
max :: a -> a -> a
plus :: a -> a -> a

instance (MaxPlus a) => MaxPlus (NameMap c a n) where
instance (MaxPlus (e n)) => MaxPlus (NameMapE c e n) where
zero = mempty
max = unionWithNameMap max
plus = unionWithNameMap plus

deriving instance (MaxPlus (e n)) => MaxPlus (NameMapE c e n)
max = unionWithNameMapE max
plus = unionWithNameMapE plus

-- === Access ===

Expand Down
8 changes: 4 additions & 4 deletions src/lib/Vectorize.hs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ askVectorByteWidth :: TopVectorizeM i o Word32
askVectorByteWidth = TopVectorizeM $ SubstReaderT $ lift $ lift11 (fromLiftE <$> ask)

extendCommuteMap :: AtomName SimpIR o -> MonoidCommutes -> TopVectorizeM i o a -> TopVectorizeM i o a
extendCommuteMap name commutativity = local $ insertNameMap name commutativity
extendCommuteMap name commutativity = local $ insertNameMapE name $ LiftE commutativity

vectorizeLoopsDestBlock :: DestBlock i
-> TopVectorizeM i o (DestBlock o)
Expand Down Expand Up @@ -309,9 +309,9 @@ vectorSafeEffect (EffectRow effs NoTail) = allM safe $ eSetToList effs where
safe (RWSEffect Writer (Var h)) = do
h' <- renameM $ atomVarName h
commuteMap <- ask
case lookupNameMap h' commuteMap of
Just Commutes -> return True
Just DoesNotCommute -> return False
case lookupNameMapE h' commuteMap of
Just (LiftE Commutes) -> return True
Just (LiftE DoesNotCommute) -> return False
Nothing -> error $ "Handle " ++ pprint h ++ " not present in commute map?"
safe _ = return False

Expand Down

0 comments on commit 3cbde4c

Please sign in to comment.