Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize PatchPair with Quant wrapper #456

Merged
merged 10 commits into from
Dec 10, 2024
Prev Previous commit
Next Next commit
implement PatchPair using new Quant datatype
danmatichuk committed Dec 10, 2024
commit ed87d7f16a4db71312e86af663497b4253e6d8c1
10 changes: 9 additions & 1 deletion src/Data/Parameterized/TotalMapF.hs
Original file line number Diff line number Diff line change
@@ -30,6 +30,8 @@ module Data.Parameterized.TotalMapF
, apply
, compose
, zip
, mapWithKey
, traverseWithKey
) where

import Prelude hiding ( zip )
@@ -55,6 +57,12 @@ newtype TotalMapF (a :: k -> Type) (b :: k -> Type) = TotalMapF (MapF a b)
instance TraversableF (TotalMapF a) where
traverseF f (TotalMapF tm) = TotalMapF <$> traverseF f tm

mapWithKey :: (forall x. a x -> b x -> c x) -> TotalMapF a b -> TotalMapF a c
mapWithKey f (TotalMapF m) = TotalMapF $ MapF.mapWithKey f m

traverseWithKey :: Applicative m => (forall x. a x -> b x -> m (c x)) -> TotalMapF a b -> m (TotalMapF a c)
traverseWithKey f (TotalMapF m) = TotalMapF <$> MapF.traverseWithKey f m

instance (TestEquality a, (forall x. (Eq (b x)))) => Eq (TotalMapF a b) where
m1 == m2 = all (\(MapF.Pair _ (PairF b1 b2)) -> b1 == b2) (zipToList m1 m2)

@@ -77,7 +85,7 @@ class HasTotalMapF a where
totalMapRepr :: forall a. (OrdF a, HasTotalMapF a) => TotalMapF a (Const ())
totalMapRepr = TotalMapF $ MapF.fromList (map (\(Some x) -> MapF.Pair x (Const ())) $ allValues @a)

apply :: OrdF a => TotalMapF a b -> (forall x. a x -> b x)
apply :: OrdF a => TotalMapF (a :: k -> Type) b -> (forall x. a x -> b x)
apply (TotalMapF m) k = case MapF.lookup k m of
Just v -> v
Nothing -> error "TotalMapF apply: internal failure. Likely 'HasTotalMapF' instance is incomplete."
112 changes: 106 additions & 6 deletions src/Data/Quant.hs
Original file line number Diff line number Diff line change
@@ -53,16 +53,25 @@ module Data.Quant
, ToQuant(..)
, HasReprK(..)
, ToMaybeQuant(..)
, pattern QuantToOne
, generateAll
, generateAllM
, pattern QuantAsAll
, pattern QuantAsOne
) where

import Data.Kind (Type)
import Data.Kind (Type)
import Data.Constraint

import Data.Functor.Const
import Data.Proxy
import Data.Parameterized.TraversableF
import Data.Parameterized.TraversableFC
import Data.Parameterized.Classes
import Data.Parameterized.Some
import qualified Data.Parameterized.TotalMapF as TMF
import Data.Parameterized.TotalMapF ( TotalMapF, HasTotalMapF )
import Data.Parameterized.WithRepr

-- | Wraps the kind 'k' with additional cases for existential and
-- universal quantification
@@ -72,12 +81,13 @@ type OneK = 'OneK
type ExistsK = 'ExistsK
type AllK = 'AllK

type KnownHasRepr (k0 :: k) = KnownRepr (ReprOf :: k -> Type) k0

-- | Similar to 'KnownRepr' and 'IsRepr' but defines a specific type 'ReprOf' that serves as the runtime representation of
-- the type parameter for values of type 'f k'
class (HasTotalMapF (ReprOf :: k -> Type), TestEquality (ReprOf :: k -> Type), OrdF (ReprOf :: k -> Type)) => HasReprK k where
class (HasTotalMapF (ReprOf :: k -> Type), TestEquality (ReprOf :: k -> Type), OrdF (ReprOf :: k -> Type), IsRepr (ReprOf :: k -> Type)) => HasReprK k where
type ReprOf :: k -> Type

-- we need this so that quantification is necessarily bounded in order to meaningfully compare universally quantified values

allReprs :: forall k. HasReprK k => TotalMapF (ReprOf :: k -> Type) (Const ())
allReprs = TMF.totalMapRepr @(ReprOf :: k -> Type)

@@ -96,6 +106,12 @@ data Quant (f :: k0 -> Type) (tp :: QuantK k0) where
QuantExists :: Quant f (OneK k) -> Quant f ExistsK
QuantAny :: Quant f AllK -> Quant f ExistsK

generateAll :: HasReprK k => (forall (x :: k). ReprOf x -> f x) -> Quant f AllK
generateAll f = QuantAll $ TMF.mapWithKey (\k _ -> f k) allReprs

generateAllM :: (HasReprK k, Applicative m) => (forall (x :: k). ReprOf x -> m (f x)) -> m (Quant f AllK)
generateAllM f = QuantAll <$> TMF.traverseWithKey (\k _ -> f k) allReprs

-- Drop the type information from a 'Quant' by making it existential instead.
toQuantExists :: Quant f tp1 -> Quant f ExistsK
toQuantExists x = case x of
@@ -121,6 +137,7 @@ pattern QuantSome x <- (fromQuantSome -> Just (Refl, Some x))

{-# COMPLETE QuantOne, QuantAll, QuantSome #-}


instance FunctorFC Quant where
fmapFC f = \case
QuantOne repr x -> QuantOne repr (f x)
@@ -139,6 +156,9 @@ instance forall k. HasReprK k => TraversableFC (Quant :: (k -> Type) -> QuantK k
QuantAll g -> QuantAll <$> traverseF f g
QuantSome x -> QuantSome <$> traverseFC f x




quantToRepr :: Quant f tp -> QuantRepr tp
quantToRepr = \case
QuantOne baserepr _ -> QuantOneRepr baserepr
@@ -150,6 +170,17 @@ data QuantRepr (tp :: QuantK k0) where
QuantAllRepr :: QuantRepr AllK
QuantSomeRepr :: QuantRepr ExistsK

instance KnownHasRepr x => KnownRepr (QuantRepr :: QuantK k0 -> Type) (OneK (x :: k0)) where
knownRepr = QuantOneRepr knownRepr

instance KnownRepr QuantRepr AllK where
knownRepr = QuantAllRepr

instance KnownRepr QuantRepr ExistsK where
knownRepr = QuantSomeRepr

instance IsRepr (ReprOf :: k -> Type) => IsRepr (QuantRepr :: QuantK k -> Type)

instance forall k. (HasReprK k) => TestEquality (QuantRepr :: QuantK k -> Type) where
testEquality (QuantOneRepr r1) (QuantOneRepr r2) = case testEquality r1 r2 of
Just Refl -> Just Refl
@@ -175,7 +206,7 @@ instance forall k. (HasReprK k) => OrdF (QuantRepr :: QuantK k -> Type) where
compareF QuantSomeRepr (QuantOneRepr{}) = GTF
compareF QuantSomeRepr QuantAllRepr = GTF

instance forall k f. (HasReprK k, (forall x. Ord (f x))) => TestEquality (Quant (f :: k -> Type)) where
instance forall k f. (HasReprK k, (forall x. Eq (f x))) => TestEquality (Quant (f :: k -> Type)) where
testEquality repr1 repr2 = case (repr1, repr2) of
(QuantOne baserepr1 x1, QuantOne baserepr2 x2) -> case testEquality baserepr1 baserepr2 of
Just Refl | x1 == x2 -> Just Refl
@@ -222,6 +253,15 @@ instance forall k f. (HasReprK k, (forall x. Ord (f x))) => OrdF (Quant (f :: k
(QuantAny{}, QuantAll{}) -> GTF
(QuantAny{}, QuantExists{}) -> GTF

instance (HasReprK k, forall x. Eq (f x)) => Eq (Quant (f :: k -> Type) tp) where
q1 == q2 = case testEquality q1 q2 of
Just Refl -> True
Nothing -> False

instance (HasReprK k, forall x. Ord (f x)) => Ord (Quant (f :: k -> Type) tp) where
compare q1 q2 = toOrdering $ compareF q1 q2


-- Defining which conversions are always possible
class ToQuant f (t1 :: QuantK k) (t2 :: QuantK k) where
toQuant :: f t1 -> QuantRepr t2 -> f t2
@@ -260,4 +300,64 @@ instance HasReprK k => ToMaybeQuant (Quant f) (tp1 :: QuantK k) (tp2 :: QuantK k
(QuantOne{}, QuantAllRepr) -> Nothing
-- in general we could consider types that themselves have defined conversions between
-- their type parameters (i.e nested Quants), but this level of generalization seems excessive without
-- good reason
-- good reason


data AsOneK (f :: QuantK k -> Type) (y :: k) where
AsOneK :: f (OneK y) -> AsOneK f y


class (Antecedent p c => c) => Implies p c where
type Antecedent p c :: Constraint

instance Implies (IsOneK AllK) c where
type Antecedent (IsOneK AllK) c = c

instance Implies (IsOneK ExistsK) c where
type Antecedent (IsOneK ExistsK) c = c

instance c => Implies (IsOneK (OneK k)) c where
type Antecedent (IsOneK (OneK k)) c = IsOneK (OneK k)

class (tp ~ (OneK (TheOneK tp))) => IsOneK tp where
type TheOneK tp :: k

instance IsOneK (OneK k) where
type TheOneK (OneK k) = k


asQuantOne :: forall k (x :: k) f tp. HasReprK k => ReprOf x -> Quant (f :: k -> Type) (tp :: QuantK k) -> Maybe (Dict (KnownRepr QuantRepr tp), Dict (Implies (IsOneK tp) (x ~ TheOneK tp)), ReprOf x, f x)
asQuantOne repr = \case
QuantOne repr' f | Just Refl <- testEquality repr' repr -> Just (withRepr (QuantOneRepr repr') $ Dict, Dict, repr, f)
QuantOne{} -> Nothing
QuantAll tm -> Just (Dict, Dict, repr, TMF.apply tm repr)
QuantExists x -> case asQuantOne repr x of
Just (Dict, _, _, x') -> Just (Dict, Dict, repr, x')
Nothing -> Nothing
QuantAny (QuantAll f) -> Just (Dict, Dict, repr, TMF.apply f repr)

-- | Cast a 'Quant' to a specific instance of 'x' if it contains it. Pattern match failure otherwise.
pattern QuantToOne :: forall {k} x f tp. (KnownHasRepr (x :: k), HasReprK k) => ( KnownRepr QuantRepr tp, Implies (IsOneK tp) (x ~ TheOneK tp)) => f x -> Quant f tp
pattern QuantToOne fx <- (asQuantOne (knownRepr :: ReprOf x) -> Just (Dict, Dict, _, fx))


-- | Project a total function from a 'Quant' if it is universally quantified.
pattern QuantAsAll :: forall {k} f tp. (HasReprK k) => () => (forall x. ReprOf x -> f x) -> Quant (f :: k -> Type) tp
pattern QuantAsAll f <- ((\l -> (case toMaybeQuant l QuantAllRepr of Just (QuantAll f) -> Just (TMF.apply f); _ -> Nothing) :: Maybe (forall (x :: k). ReprOf x -> f x) ) -> Just (f))

data QuantAsOneProof (f :: k -> Type) (tp :: QuantK k) where
QuantAsOneProof :: (KnownRepr QuantRepr tp, Implies (IsOneK tp) (x ~ TheOneK tp)) => ReprOf x -> f x -> QuantAsOneProof f tp


quantAsOne :: forall k f tp. HasReprK k => Quant (f :: k -> Type) (tp :: QuantK k) -> Maybe (QuantAsOneProof f tp)
quantAsOne q = case q of
QuantOne repr x-> withRepr repr $ Just (QuantAsOneProof repr x)
QuantExists q' -> case quantAsOne q' of
Just (QuantAsOneProof repr x) -> Just $ QuantAsOneProof repr x
Nothing -> Nothing
_ -> Nothing


-- | Project out the element of a singleton 'Quant'
pattern QuantAsOne :: forall {k} f tp. (HasReprK k) => forall x. (KnownRepr QuantRepr tp, Implies (IsOneK tp) (x ~ TheOneK tp)) => ReprOf x -> f x -> Quant (f :: k -> Type) tp
pattern QuantAsOne repr x <- (quantAsOne -> Just (QuantAsOneProof repr x))
10 changes: 10 additions & 0 deletions src/Pate/Binary.hs
Original file line number Diff line number Diff line change
@@ -39,6 +39,9 @@ where
import Data.Parameterized.WithRepr
import Data.Parameterized.Classes
import Data.Parameterized.Some
import qualified Data.Parameterized.TotalMapF as TMF
import qualified Data.Quant as Qu
import Data.Quant ( Quant, QuantK)
import qualified Prettyprinter as PP
import Pate.TraceTree

@@ -119,3 +122,10 @@ instance KnownRepr WhichBinaryRepr Patched where
type KnownBinary (bin :: WhichBinary) = KnownRepr WhichBinaryRepr bin

instance IsRepr WhichBinaryRepr

instance TMF.HasTotalMapF WhichBinaryRepr where
allValues = [Some OriginalRepr, Some PatchedRepr]

instance Qu.HasReprK WhichBinary where
type ReprOf = WhichBinaryRepr

6 changes: 3 additions & 3 deletions src/Pate/Equivalence.hs
Original file line number Diff line number Diff line change
@@ -526,8 +526,8 @@ eqDomPre ::
eqDomPre sym stO stP (eqCtxHDR -> hdr) eqDom = do
let
st = PPa.PatchPair stO stP
maxRegion = TF.fmapF (\st' -> Const $ unSE $ simMaxRegion st') st
stackBase = TF.fmapF (\st' -> Const $ unSE $ simStackBase st') st
maxRegion = PPa.map (\st' -> Const $ unSE $ simMaxRegion st') st
stackBase = PPa.map (\st' -> Const $ unSE $ simStackBase st') st

regsEq <- regDomRel hdr sym stO stP (PED.eqDomainRegisters eqDom)
maxRegionsEq <- mkNamedAsm sym (PED.eqDomainMaxRegion eqDom) (bindExprPair maxRegion)
@@ -563,7 +563,7 @@ eqDomPost sym stO stP eqCtx domPre domPost = do
st = PPa.PatchPair stO stP
hdr = eqCtxHDR eqCtx
stackRegion = eqCtxStackRegion eqCtx
maxRegion = TF.fmapF (\st' -> Const $ unSE $ simMaxRegion st') st
maxRegion = PPa.map (\st' -> Const $ unSE $ simMaxRegion st') st

regsEq <- regDomRel hdr sym stO stP (PED.eqDomainRegisters domPost)
stacksEq <- memDomPost sym (MemEqAtRegion stackRegion) stO stP (PED.eqDomainStackMemory domPre) (PED.eqDomainStackMemory domPost)
13 changes: 7 additions & 6 deletions src/Pate/Interactive/Render/Proof.hs
Original file line number Diff line number Diff line change
@@ -26,6 +26,7 @@ import qualified Data.Parameterized.Map as MapF
import qualified Data.Parameterized.NatRepr as PN
import Data.Parameterized.Some ( Some(..) )
import qualified Data.Parameterized.TraversableF as TF
import qualified Data.Parameterized.TraversableFC as TFC
import Data.Proxy (Proxy(..))
import qualified Data.Text as T
import qualified Data.Vector as DV
@@ -286,7 +287,7 @@ renderRegVal domain reg regOp =
PPa.PatchPairSingle{} -> Just prettySlotVal
_ -> Just prettySlotVal
where
vals = TF.fmapF (\(Const x) -> Const $ PFI.groundMacawValue x) $ PPr.slRegOpValues regOp
vals = TFC.fmapFC (\(Const x) -> Const $ PFI.groundMacawValue x) $ PPr.slRegOpValues regOp

ppDom =
case PFI.regInGroundDomain (PED.eqDomainRegisters domain) reg of
@@ -334,7 +335,7 @@ renderMemCellVal
-> Maybe (PP.Doc a)
renderMemCellVal domain cell memOp = do
guard (PG.groundValue $ PPr.slMemOpCond memOp)
let vals = TF.fmapF (\(Const x) -> Const $ PFI.groundBV x) $ PPr.slMemOpValues memOp
let vals = PPa.map (\(Const x) -> Const $ PFI.groundBV x) $ PPr.slMemOpValues memOp
let ppDom = case PFI.cellInGroundDomain domain cell of
True -> PP.emptyDoc
False -> PP.pretty "| Excluded"
@@ -365,7 +366,7 @@ renderIPs st
| (PG.groundValue $ PPr.slRegOpEquiv pcRegs) = PP.pretty (PPa.someC vals)
| otherwise = PPa.ppPatchPairC PP.pretty vals
where
vals = TF.fmapF (\(Const x) -> Const $ PFI.groundMacawValue x) (PPr.slRegOpValues pcRegs)
vals = PPa.map (\(Const x) -> Const $ PFI.groundMacawValue x) (PPr.slRegOpValues pcRegs)
pcRegs = PPr.slRegState st ^. MC.curIP

renderReturn
@@ -384,7 +385,7 @@ renderCounterexample
-> TP.UI TP.Element
renderCounterexample ineqRes' = PPr.withIneqResult ineqRes' $ \ineqRes ->
let
groundEnd = TF.fmapF (\(Const x) -> Const $ (PFI.groundBlockEnd (Proxy @arch)) x) $ PPr.slBlockExitCase (PPr.ineqSlice ineqRes)
groundEnd = PPa.map (\(Const x) -> Const $ (PFI.groundBlockEnd (Proxy @arch)) x) $ PPr.slBlockExitCase (PPr.ineqSlice ineqRes)
renderInequalityReason rsn =
case rsn of
PEE.InequivalentRegisters ->
@@ -397,11 +398,11 @@ renderCounterexample ineqRes' = PPr.withIneqResult ineqRes' $ \ineqRes ->
TP.string "The original and patched programs have generated invalid post states"

renderedContinuation = TP.column (catMaybes [ Just (text (PP.pretty "Next IP: " <> renderIPs (PPr.slBlockPostState (PPr.ineqSlice ineqRes))))
, renderReturn (TF.fmapF (\(Const x) -> Const $ PFI.grndBlockReturn x) groundEnd)
, renderReturn (PPa.map (\(Const x) -> Const $ PFI.grndBlockReturn x) groundEnd)
])
in
TP.grid [ [ renderInequalityReason (PPr.ineqReason ineqRes) ]
, [ text (PPa.ppPatchPairCEq (PP.pretty . PFI.ppExitCase) (TF.fmapF (\(Const x) -> Const $ PFI.grndBlockCase x) groundEnd)) ]
, [ text (PPa.ppPatchPairCEq (PP.pretty . PFI.ppExitCase) (PPa.map (\(Const x) -> Const $ PFI.grndBlockCase x) groundEnd)) ]
, [ TP.h2 #+ [TP.string "Initial states"] ]
, [ renderRegisterState (PPr.ineqPre ineqRes) (PPr.slRegState (PPr.slBlockPreState (PPr.ineqSlice ineqRes)))
, renderMemoryState (PPr.ineqPre ineqRes) (PPr.slMemState (PPr.slBlockPreState (PPr.ineqSlice ineqRes)))
2 changes: 1 addition & 1 deletion src/Pate/Location.hs
Original file line number Diff line number Diff line change
@@ -406,6 +406,6 @@ instance (LocationTraversable sym arch a, LocationTraversable sym arch b) =>

instance (forall bin. (LocationWitherable sym arch (f bin))) =>
LocationWitherable sym arch (PPa.PatchPair f) where
witherLocation sym pp f = TF.traverseF (\x -> witherLocation sym x f) pp
witherLocation sym pp f = PPa.traverse (\x -> witherLocation sym x f) pp


6 changes: 3 additions & 3 deletions src/Pate/Monad.hs
Original file line number Diff line number Diff line change
@@ -1531,7 +1531,7 @@ withPair :: PB.BlockPair arch -> EquivM sym arch a -> EquivM sym arch a
withPair pPair f = do
env <- CMR.ask
let env' = env { envParentBlocks = pPair:envParentBlocks env }
let entryPair = TF.fmapF (\b -> PB.functionEntryToConcreteBlock (PB.blockFunctionEntry b)) pPair
let entryPair = PPa.map (\b -> PB.functionEntryToConcreteBlock (PB.blockFunctionEntry b)) pPair
CMR.local (\_ -> env' & PME.envCtxL . PMC.currentFunc .~ entryPair) f

-- | Emit a trace event to the frontend
@@ -1543,7 +1543,7 @@ traceBlockPair
-> String
-> EquivM sym arch ()
traceBlockPair bp msg =
emitEvent (PE.ProofTraceEvent callStack (TF.fmapF (Const . PB.concreteAddress) bp) (T.pack msg))
emitEvent (PE.ProofTraceEvent callStack (PPa.map (Const . PB.concreteAddress) bp) (T.pack msg))

-- | Emit a trace event to the frontend
--
@@ -1554,7 +1554,7 @@ traceBundle
-> String
-> EquivM sym arch ()
traceBundle bundle msg = do
let bp = TF.fmapF (Const . PB.concreteAddress . simInBlock) (simIn bundle)
let bp = PPa.map (Const . PB.concreteAddress . simInBlock) (simIn bundle)
emitEvent (PE.ProofTraceEvent callStack bp (T.pack msg))

fnTrace ::
4 changes: 2 additions & 2 deletions src/Pate/Monad/PairGraph.hs
Original file line number Diff line number Diff line change
@@ -167,7 +167,7 @@ initializePairGraph pPairs = foldM (\x y -> initPair x y) emptyPairGraph pPairs
where
initPair :: PairGraph sym arch -> PB.FunPair arch -> EquivM sym arch (PairGraph sym arch)
initPair gr fnPair =
do let bPair = TF.fmapF PB.functionEntryToConcreteBlock fnPair
do let bPair = PPa.map PB.functionEntryToConcreteBlock fnPair
withPair bPair $ do
-- initial state of the pair graph: choose the universal domain that equates as much as possible
let node = GraphNode (rootEntry bPair)
@@ -205,7 +205,7 @@ initialDomainSpec (GraphNodeEntry blocks) = withTracing @"function_name" "initia
dom <- initialDomain
return (mempty, dom)
initialDomainSpec (GraphNodeReturn fPair) = withTracing @"function_name" "initialDomainSpec" $ do
let blocks = TF.fmapF PB.functionEntryToConcreteBlock fPair
let blocks = PPa.map PB.functionEntryToConcreteBlock fPair
withFreshVars blocks $ \_vars -> do
dom <- initialDomain
return (mempty, dom)
Loading