diff --git a/pate.cabal b/pate.cabal index e3a881c4..93efa7bd 100644 --- a/pate.cabal +++ b/pate.cabal @@ -94,6 +94,8 @@ library exposed-modules: Data.Macaw.CFGSlice, Data.RevMap, Data.UnwrapType, + Data.Parameterized.TotalMapF, + Data.Quant, Pate.Abort, Pate.Address, Pate.AssumptionSet, @@ -227,7 +229,7 @@ common shared-test -- just used for loading test modules interactively library pate-test-base import: shared-test - hs-source-dirs: tests, arch + hs-source-dirs: tests, arch, tests/types other-modules: Pate.AArch32, Pate.PPC build-depends: semmc-aarch32, macaw-aarch32, @@ -309,6 +311,13 @@ test-suite pate-test-solver what4, dismantle-arm-xml +test-suite pate-test-types + import: shared-test + type: exitcode-stdio-1.0 + main-is: TypesTestMain.hs + other-modules: QuantTest + hs-source-dirs: tests, tests/types + common shared-exec ghc-options: -Wall -Wcompat default-language: Haskell2010 diff --git a/src/Data/Parameterized/TotalMapF.hs b/src/Data/Parameterized/TotalMapF.hs new file mode 100644 index 00000000..9a3f4ccc --- /dev/null +++ b/src/Data/Parameterized/TotalMapF.hs @@ -0,0 +1,111 @@ +{-| +Module : Data.Parameterized.TotalMapF +Copyright : (c) Galois, Inc 2024 +Maintainer : Daniel Matichuk + +Reified total functions (maps) with testable equality and ordering. +Implemented as a wrapped Data.Parameterized.Map with a constrained interface. + +-} + +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralisedNewtypeDeriving #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE DefaultSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE QuantifiedConstraints #-} + +module Data.Parameterized.TotalMapF + ( + TotalMapF + , HasTotalMapF(..) + , totalMapRepr + , apply + , compose + , zip + , mapWithKey + , traverseWithKey + ) where + +import Prelude hiding ( zip ) +import qualified Data.List as List + +import Data.Kind (Type) +import Data.Functor.Const + +import Data.Parameterized.TraversableF +import Data.Parameterized.PairF +import Data.Parameterized.Classes +import Data.Parameterized.Some +import qualified Data.Parameterized.Map as MapF +import Data.Parameterized.Map ( MapF ) + +-- | A wrapped 'MapF' from 'a' to 'b' that covers all possible values of 'a' for all instances. +-- All values of 'a' is defined by 'allValues' from 'HasTotalMapF', which is trusted. +-- If 'allValues' is incomplete then the behavior of this datatype is undefined and +-- may raise runtime errors. +newtype TotalMapF (a :: k -> Type) (b :: k -> Type) = TotalMapF (MapF a b) + deriving (FoldableF, FunctorF, Show) + +instance TraversableF (TotalMapF a) where + traverseF f (TotalMapF tm) = TotalMapF <$> traverseF f tm + +mapWithKey :: (forall x. a x -> b x -> c x) -> TotalMapF a b -> TotalMapF a c +mapWithKey f (TotalMapF m) = TotalMapF $ MapF.mapWithKey f m + +traverseWithKey :: Applicative m => (forall x. a x -> b x -> m (c x)) -> TotalMapF a b -> m (TotalMapF a c) +traverseWithKey f (TotalMapF m) = TotalMapF <$> MapF.traverseWithKey f m + +instance (TestEquality a, (forall x. (Eq (b x)))) => Eq (TotalMapF a b) where + m1 == m2 = all (\(MapF.Pair _ (PairF b1 b2)) -> b1 == b2) (zipToList m1 m2) + +instance (OrdF a, (forall x. (Ord (b x)))) => Ord (TotalMapF a b) where + compare m1 m2 = compareZipped (map (\(MapF.Pair _ v) -> Some v) $ zipToList m1 m2) + +compareZipped :: (forall x. (Ord (a x))) => [Some (PairF a a)] -> Ordering +compareZipped (Some (PairF x1 x2):xs) = case compare x1 x2 of + EQ -> compareZipped xs + LT -> LT + GT -> GT +compareZipped [] = EQ + +class HasTotalMapF a where + -- | A list of all possible values for this type (for all possible instances). + -- TODO: Unclear how this will behave if defined for infinite types via a lazy list + allValues :: [Some a] + +-- | Canonical total map for a given type. Use FunctorF instance to create maps to other types. +totalMapRepr :: forall a. (OrdF a, HasTotalMapF a) => TotalMapF a (Const ()) +totalMapRepr = TotalMapF $ MapF.fromList (map (\(Some x) -> MapF.Pair x (Const ())) $ allValues @a) + +apply :: OrdF a => TotalMapF (a :: k -> Type) b -> (forall x. a x -> b x) +apply (TotalMapF m) k = case MapF.lookup k m of + Just v -> v + Nothing -> error "TotalMapF apply: internal failure. Likely 'HasTotalMapF' instance is incomplete." + +compose :: (OrdF a, OrdF b) => TotalMapF a b -> TotalMapF b c -> TotalMapF a c +compose atoB btoC = fmapF (apply btoC) atoB + +-- | Same as 'zip' but skips re-building the map +zipToList :: TestEquality a => TotalMapF a b -> TotalMapF a c -> [MapF.Pair a (PairF b c)] +zipToList m1 m2 = + let + m1' = toList m1 + m2' = toList m2 + err = error "TotalMapF zip: internal failure. Likely 'HasTotalMapF' instance is incomplete." + in case length m1' == length m2' of + True -> map (\(MapF.Pair a1 b, MapF.Pair a2 c) -> case testEquality a1 a2 of Just Refl -> MapF.Pair a1 (PairF b c); Nothing -> err) (List.zip m1' m2') + False -> err + +zip :: OrdF a => TotalMapF a b -> TotalMapF a c -> TotalMapF a (PairF b c) +zip m1 m2 = TotalMapF (MapF.fromList $ zipToList m1 m2) + +toList :: TotalMapF a b -> [MapF.Pair a b] +toList (TotalMapF m) = MapF.toList m diff --git a/src/Data/Quant.hs b/src/Data/Quant.hs new file mode 100644 index 00000000..f83ebc2a --- /dev/null +++ b/src/Data/Quant.hs @@ -0,0 +1,537 @@ +{-| +Module : Data.Quant +Copyright : (c) Galois, Inc 2024 +Maintainer : Daniel Matichuk + +A container that is used to conveniently define datatypes that +are generalized over concrete, existential and universal quantification. + +-} + + +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeSynonymInstances #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE StandaloneKindSignatures #-} + +{-# OPTIONS_GHC -fno-warn-orphans #-} +{-# LANGUAGE UndecidableSuperClasses #-} +{-# LANGUAGE ImpredicativeTypes #-} +{-# LANGUAGE EmptyCase #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE DefaultSignatures #-} + +module Data.Quant + ( + Quant(..) + , type QuantK + , type OneK + , type ExistsK + , type AllK + , map + , traverse + , mapWithRepr + , traverseWithRepr + , pattern QuantSome + , toQuantExists + , quantToRepr + , QuantRepr(..) + , QuantConversion(..) + , QuantConvertible(..) + , QuantCoercion(..) + , QuantCoercible(..) + , HasReprK(..) + , pattern QuantToOne + , generateAll + , generateAllM + , pattern All + , pattern Single + , viewQuantEach + , pattern QuantEach + , AsSingle(..) + ) where + +import Prelude hiding (map, traverse) + +import Data.Kind (Type) +import Data.Constraint + +import Data.Functor.Const +import Data.Proxy +import Data.Parameterized.TraversableF +import Data.Parameterized.TraversableFC +import Data.Parameterized.Classes +import Data.Parameterized.Some +import qualified Data.Parameterized.TotalMapF as TMF +import Data.Parameterized.TotalMapF ( TotalMapF, HasTotalMapF ) +import Data.Parameterized.WithRepr + +-- | Wraps the kind 'k' with additional cases for existential and +-- universal quantification +data QuantK k = OneK k | ExistsK | AllK + +type OneK = 'OneK +type ExistsK = 'ExistsK +type AllK = 'AllK + +type KnownHasRepr (k0 :: k) = KnownRepr (ReprOf :: k -> Type) k0 + +-- | Similar to 'KnownRepr' and 'IsRepr' but defines a specific type 'ReprOf' that serves as the runtime representation of +-- the type parameter for values of type 'f k' +class (HasTotalMapF (ReprOf :: k -> Type), TestEquality (ReprOf :: k -> Type), OrdF (ReprOf :: k -> Type), IsRepr (ReprOf :: k -> Type)) => HasReprK k where + type ReprOf :: k -> Type + +allReprs :: forall k. HasReprK k => TotalMapF (ReprOf :: k -> Type) (Const ()) +allReprs = TMF.totalMapRepr @(ReprOf :: k -> Type) + +-- | Wraps a kind 'k -> Type' to represent the following possible cases: +-- * a single value of type 'f k' (i.e. f x ~ Quant f (OneK x)) +-- * all possible values of type 'f k' (i.e. (forall k. f k) ~ Quant f AllK) +-- * existentially quantified over the above two cases (i.e. Some f ~ Quant f ExistsK ~ Some (Quant f)) +-- By universally quantifying types and functions over 'Quant k' we can implicitly handle all 3 of the +-- above cases, rather than requiring individual implementations for each. +data Quant (f :: k0 -> Type) (tp :: QuantK k0) where + QuantOne :: ReprOf k -> f k -> Quant f (OneK k) + -- ^ a single value of type 'f k' + QuantAll :: TotalMapF ReprOf f -> Quant f AllK + + -- the above two cases, but existentially wrapped + QuantExists :: Quant f (OneK k) -> Quant f ExistsK + QuantAny :: Quant f AllK -> Quant f ExistsK + +generateAll :: HasReprK k => (forall (x :: k). ReprOf x -> f x) -> Quant f AllK +generateAll f = QuantAll $ TMF.mapWithKey (\k _ -> f k) allReprs + +generateAllM :: (HasReprK k, Applicative m) => (forall (x :: k). ReprOf x -> m (f x)) -> m (Quant f AllK) +generateAllM f = QuantAll <$> TMF.traverseWithKey (\k _ -> f k) allReprs + +-- Drop the type information from a 'Quant' by making it existential instead. +toQuantExists :: Quant f tp1 -> Quant f ExistsK +toQuantExists x = case x of + QuantOne{} -> QuantExists x + QuantAll{} -> QuantAny x + QuantExists y -> toQuantExists y + QuantAny y -> toQuantExists y + + +fromQuantSome :: Quant f tp -> Maybe (tp :~: ExistsK, Some (Quant f)) +fromQuantSome x = case x of + QuantExists y -> Just (Refl,Some y) + QuantAny y -> Just (Refl,Some y) + _ -> Nothing + +-- | A more convenient interface for handling existential cases, which +-- doesn't distinguish between universal or concrete for the wrapped +-- Quant. +pattern QuantSome :: () => (tp2 ~ ExistsK) => Quant f tp1 -> Quant f tp2 +pattern QuantSome x <- (fromQuantSome -> Just (Refl, Some x)) + where + QuantSome x = toQuantExists x + +{-# COMPLETE QuantOne, QuantAll, QuantSome #-} + + +instance FunctorFC Quant where + fmapFC f = \case + QuantOne repr x -> QuantOne repr (f x) + QuantAll g -> QuantAll (fmapF f g) + QuantSome x -> QuantSome (fmapFC f x) + +instance forall k. HasReprK k => FoldableFC (Quant :: (k -> Type) -> QuantK k -> Type) where + foldrFC f b = \case + QuantOne _ x -> f x b + QuantAll g -> foldrF f b g + QuantSome x -> foldrFC f b x + +instance forall k. HasReprK k => TraversableFC (Quant :: (k -> Type) -> QuantK k -> Type) where + traverseFC f = \case + QuantOne repr x -> QuantOne <$> pure repr <*> f x + QuantAll g -> QuantAll <$> traverseF f g + QuantSome x -> QuantSome <$> traverseFC f x + +map :: (forall x. f x -> g x) -> Quant f tp -> Quant g tp +map = fmapFC + +mapWithRepr :: (forall (x :: k). ReprOf x -> f x -> g x) -> Quant f tp -> Quant g tp +mapWithRepr f = \case + QuantOne repr x -> QuantOne repr $ f repr x + QuantAll tm -> QuantAll $ TMF.mapWithKey f tm + QuantSome x -> QuantSome $ mapWithRepr f x + +traverse :: (HasReprK k, Applicative m) => (forall (x :: k). f x -> m (g x)) -> Quant f tp -> m (Quant g tp) +traverse = traverseFC + +traverseWithRepr :: (HasReprK k, Applicative m) => (forall (x :: k). ReprOf x -> f x -> m (g x)) -> Quant f tp -> m (Quant g tp) +traverseWithRepr f = \case + QuantOne repr x -> QuantOne <$> pure repr <*> f repr x + QuantAll tm -> QuantAll <$> TMF.traverseWithKey f tm + QuantSome x -> QuantSome <$> traverseWithRepr f x + +quantToRepr :: Quant f tp -> QuantRepr tp +quantToRepr = \case + QuantOne baserepr _ -> QuantOneRepr baserepr + QuantAll{} -> QuantAllRepr + QuantSome{} -> QuantSomeRepr + +data QuantRepr (tp :: QuantK k0) where + QuantOneRepr :: ReprOf k -> QuantRepr (OneK k) + QuantAllRepr :: QuantRepr AllK + QuantSomeRepr :: QuantRepr ExistsK + +instance KnownHasRepr x => KnownRepr (QuantRepr :: QuantK k0 -> Type) (OneK (x :: k0)) where + knownRepr = QuantOneRepr knownRepr + +instance KnownRepr QuantRepr AllK where + knownRepr = QuantAllRepr + +instance KnownRepr QuantRepr ExistsK where + knownRepr = QuantSomeRepr + +instance IsRepr (ReprOf :: k -> Type) => IsRepr (QuantRepr :: QuantK k -> Type) + +instance forall k. (HasReprK k) => TestEquality (QuantRepr :: QuantK k -> Type) where + testEquality (QuantOneRepr r1) (QuantOneRepr r2) = case testEquality r1 r2 of + Just Refl -> Just Refl + Nothing -> Nothing + testEquality QuantAllRepr QuantAllRepr = Just Refl + testEquality QuantSomeRepr QuantSomeRepr = Just Refl + testEquality _ _ = Nothing + +instance forall k. (HasReprK k) => OrdF (QuantRepr :: QuantK k -> Type) where + compareF (QuantOneRepr r1) (QuantOneRepr r2) = case compareF r1 r2 of + EQF -> EQF + LTF -> LTF + GTF -> GTF + compareF QuantAllRepr QuantAllRepr = EQF + compareF QuantSomeRepr QuantSomeRepr = EQF + + compareF (QuantOneRepr{}) QuantAllRepr = LTF + compareF (QuantOneRepr{}) QuantSomeRepr = LTF + + compareF QuantAllRepr (QuantOneRepr{}) = GTF + compareF QuantAllRepr QuantSomeRepr = LTF + + compareF QuantSomeRepr (QuantOneRepr{}) = GTF + compareF QuantSomeRepr QuantAllRepr = GTF + +instance forall k f. (HasReprK k, (forall x. Eq (f x))) => TestEquality (Quant (f :: k -> Type)) where + testEquality repr1 repr2 = case (repr1, repr2) of + (QuantOne baserepr1 x1, QuantOne baserepr2 x2) -> case testEquality baserepr1 baserepr2 of + Just Refl | x1 == x2 -> Just Refl + _ -> Nothing + (QuantAll g1, QuantAll g2) -> case g1 == g2 of + True -> Just Refl + False -> Nothing + (QuantExists x1, QuantExists x2) -> case testEquality x1 x2 of + Just Refl -> Just Refl + Nothing -> Nothing + (QuantAny x1, QuantAny x2) -> case testEquality x1 x2 of + Just Refl -> Just Refl + Nothing -> Nothing + _ -> Nothing + + +instance forall k f. (HasReprK k, (forall x. Ord (f x))) => OrdF (Quant (f :: k -> Type)) where + compareF repr1 repr2 = case (repr1, repr2) of + (QuantOne baserepr1 x1, QuantOne baserepr2 x2) -> lexCompareF baserepr1 baserepr2 $ fromOrdering (compare x1 x2) + (QuantAll g1, QuantAll g2) -> fromOrdering (compare g1 g2) + (QuantExists x1, QuantExists x2) -> case compareF x1 x2 of + LTF -> LTF + GTF -> GTF + EQF -> EQF + (QuantAny x1, QuantAny x2) -> case compareF x1 x2 of + LTF -> LTF + GTF -> GTF + EQF -> EQF + + -- based on constructor ordering + (QuantOne{}, QuantAll{}) -> LTF + (QuantOne{}, QuantExists{}) -> LTF + (QuantOne{}, QuantAny{}) -> LTF + + (QuantAll{}, QuantOne{}) -> GTF + (QuantAll{}, QuantExists{}) -> LTF + (QuantAll{}, QuantAny{}) -> LTF + + (QuantExists{}, QuantOne{}) -> GTF + (QuantExists{}, QuantAll{}) -> GTF + (QuantExists{}, QuantAny{}) -> LTF + + (QuantAny{}, QuantOne{}) -> GTF + (QuantAny{}, QuantAll{}) -> GTF + (QuantAny{}, QuantExists{}) -> GTF + +instance (HasReprK k, forall x. Eq (f x)) => Eq (Quant (f :: k -> Type) tp) where + q1 == q2 = case testEquality q1 q2 of + Just Refl -> True + Nothing -> False + +instance (HasReprK k, forall x. Ord (f x)) => Ord (Quant (f :: k -> Type) tp) where + compare q1 q2 = toOrdering $ compareF q1 q2 + +data QuantCoercion (t1 :: QuantK k) (t2 :: QuantK k) where + CoerceAllToOne :: ReprOf x -> QuantCoercion AllK (OneK x) + CoerceAllToExists :: QuantCoercion AllK ExistsK + CoerceOneToExists :: QuantCoercion (OneK x) ExistsK + CoerceRefl :: QuantCoercion x x + +class QuantCoercible (f :: QuantK k -> Type) where + applyQuantCoercion :: forall t1 t2. HasReprK k => QuantCoercion t1 t2 -> f t1 -> f t2 + applyQuantCoercion qc f1 = withRepr qc $ coerceQuant f1 + + coerceQuant :: forall t1 t2. (HasReprK k, KnownCoercion t1 t2) => f t1 -> f t2 + coerceQuant = applyQuantCoercion knownRepr + +instance HasReprK k => IsRepr (QuantCoercion (t1 :: QuantK k)) where + withRepr x f = case x of + CoerceAllToOne repr -> withRepr repr $ f + CoerceAllToExists -> f + CoerceOneToExists -> f + CoerceRefl -> f + +instance QuantCoercible (Quant (f :: k -> Type)) where + applyQuantCoercion qc q = case (qc, q) of + (CoerceAllToOne repr, QuantAll f) -> QuantOne repr (TMF.apply f repr) + (CoerceAllToExists, QuantAll{}) -> QuantAny q + (CoerceOneToExists, QuantOne{}) -> QuantExists q + (CoerceRefl, _) -> q + +type KnownCoercion (tp1 :: QuantK k) (tp2 :: QuantK k) = KnownRepr (QuantCoercion tp1) tp2 + + +instance (KnownRepr (ReprOf :: k -> Type) (x :: k)) => KnownRepr (QuantCoercion AllK) (OneK x) where + knownRepr = CoerceAllToOne knownRepr + +instance KnownRepr (QuantCoercion AllK) ExistsK where + knownRepr = CoerceAllToExists + +instance KnownRepr (QuantCoercion (OneK x)) ExistsK where + knownRepr = CoerceOneToExists + +instance KnownRepr (QuantCoercion x) x where + knownRepr = CoerceRefl + + +data QuantConversion (t1 :: QuantK k) (t2 :: QuantK k) where + ConvertWithCoerce :: QuantCoercion t1 t2 -> QuantConversion t1 t2 + ConvertExistsToAll :: QuantConversion ExistsK AllK + ConvertExistsToOne :: ReprOf x -> QuantConversion ExistsK (OneK x) + +instance HasReprK k => IsRepr (QuantConversion (t1 :: QuantK k)) where + withRepr x f = case x of + ConvertWithCoerce y -> case y of + CoerceAllToOne repr -> withRepr repr $ f + CoerceAllToExists -> f + CoerceOneToExists -> f + CoerceRefl -> f + ConvertExistsToAll -> f + ConvertExistsToOne repr -> withRepr repr $ f + +class QuantConvertible (f :: QuantK k -> Type) where + applyQuantConversion :: forall t1 t2. HasReprK k => QuantConversion t1 t2 -> f t1 -> Maybe (f t2) + applyQuantConversion qc f1 = withRepr qc $ convertQuant f1 + + convertQuant :: forall t1 t2. (HasReprK k, KnownConversion t1 t2) => f t1 -> Maybe (f t2) + convertQuant = applyQuantConversion knownRepr + +type KnownConversion (tp1 :: QuantK k) (tp2 :: QuantK k) = KnownRepr (QuantConversion tp1) tp2 + +instance (KnownRepr (ReprOf :: k -> Type) (x :: k)) => KnownRepr (QuantConversion AllK) (OneK x) where + knownRepr = ConvertWithCoerce knownRepr + +instance KnownRepr (QuantConversion AllK) ExistsK where + knownRepr = ConvertWithCoerce knownRepr + +instance KnownRepr (QuantConversion (OneK x)) ExistsK where + knownRepr = ConvertWithCoerce knownRepr + +instance KnownRepr (QuantConversion x) x where + knownRepr = ConvertWithCoerce knownRepr + +instance KnownRepr (QuantConversion ExistsK) AllK where + knownRepr = ConvertExistsToAll + +instance (KnownRepr (ReprOf :: k -> Type) (x :: k)) => KnownRepr (QuantConversion ExistsK) (OneK x) where + knownRepr = ConvertExistsToOne knownRepr + + +instance QuantConvertible (Quant (f :: k -> Type)) where + applyQuantConversion qc q = case (qc, q) of + (ConvertWithCoerce qc', _) -> Just (applyQuantCoercion qc' q) + (ConvertExistsToAll, QuantAny q') -> Just q' + (ConvertExistsToAll, QuantExists{}) -> Nothing + (ConvertExistsToOne repr, QuantAny q') -> Just (applyQuantCoercion (CoerceAllToOne repr) q') + (ConvertExistsToOne repr, QuantExists q'@(QuantOne repr' _)) -> case testEquality repr repr' of + Just Refl -> Just q' + Nothing -> Nothing + +type family TheOneK (tp :: QuantK k) :: k where + TheOneK (OneK k) = k + +type family IfIsOneK (tp :: QuantK k) (c :: Constraint) :: Constraint where + IfIsOneK (OneK k) c = c + IfIsOneK AllK c = () + IfIsOneK ExistsK c = () + +asQuantOne :: forall k (x :: k) f tp. HasReprK k => ReprOf x -> Quant (f :: k -> Type) (tp :: QuantK k) -> Maybe (Dict (KnownRepr QuantRepr tp), Dict (IfIsOneK tp (x ~ TheOneK tp)), ReprOf x, f x) +asQuantOne repr = \case + QuantOne repr' f | Just Refl <- testEquality repr' repr -> Just (withRepr (QuantOneRepr repr') $ Dict, Dict, repr, f) + QuantOne{} -> Nothing + QuantAll tm -> Just (Dict, Dict, repr, TMF.apply tm repr) + QuantExists x -> case asQuantOne repr x of + Just (Dict, _, _, x') -> Just (Dict, Dict, repr, x') + Nothing -> Nothing + QuantAny (QuantAll f) -> Just (Dict, Dict, repr, TMF.apply f repr) + +-- | Cast a 'Quant' to a specific instance of 'x' if it contains it. Pattern match failure otherwise. +pattern QuantToOne :: forall {k} x f tp. (KnownHasRepr (x :: k), HasReprK k) => ( KnownRepr QuantRepr tp, IfIsOneK tp (x ~ TheOneK tp)) => f x -> Quant f tp +pattern QuantToOne fx <- (asQuantOne (knownRepr :: ReprOf x) -> Just (Dict, Dict, _, fx)) + + +data ExistsOrCases (tp1 :: QuantK k) (tp2 :: QuantK k) where + ExistsOrRefl :: ExistsOrCases tp tp + ExistsOrExists :: ExistsOrCases ExistsK tp + +type family IsExistsOrConstraint (tp1 :: QuantK k) (tp2 :: QuantK k) :: Constraint + +class IsExistsOrConstraint tp1 tp2 => IsExistsOr (tp1 :: QuantK k) (tp2 :: QuantK k) where + isExistsOr :: ExistsOrCases tp1 tp2 + +type instance IsExistsOrConstraint (OneK x) tp = ((OneK x) ~ tp) +type instance IsExistsOrConstraint (AllK :: QuantK k) tp = ((AllK :: QuantK k) ~ tp) + +instance IsExistsOr (OneK x) (OneK x) where + isExistsOr = ExistsOrRefl + +instance IsExistsOr AllK AllK where + isExistsOr = ExistsOrRefl + +instance IsExistsOr ExistsK ExistsK where + isExistsOr = ExistsOrRefl + +type instance IsExistsOrConstraint ExistsK x = () + +instance IsExistsOr ExistsK (OneK k) where + isExistsOr = ExistsOrExists + +instance IsExistsOr ExistsK AllK where + isExistsOr = ExistsOrExists + +data QuantAsAllProof (f :: k -> Type) (tp :: QuantK k) where + QuantAsAllProof :: (IsExistsOr tp AllK) => (forall x. ReprOf x -> f x) -> QuantAsAllProof f tp + +quantAsAll :: HasReprK k => Quant (f :: k -> Type) tp -> Maybe (QuantAsAllProof f tp) +quantAsAll q = case q of + QuantOne{} -> Nothing + QuantAll f -> Just (QuantAsAllProof (TMF.apply f)) + QuantSome q' -> case quantAsAll q' of + Just (QuantAsAllProof f) -> Just $ QuantAsAllProof f + Nothing -> Nothing + +-- | Pattern for creating or matching a universally quantified 'Quant', generalized over the existential cases +pattern All :: forall {k} f tp. (HasReprK k) => (IsExistsOr tp AllK) => (forall x. ReprOf x -> f x) -> Quant (f :: k -> Type) tp +pattern All f <- (quantAsAll -> Just (QuantAsAllProof f)) + where + All f = case (isExistsOr :: ExistsOrCases tp AllK) of + ExistsOrExists -> QuantAny (All f) + ExistsOrRefl -> QuantAll (TMF.mapWithKey (\repr _ -> f repr) (allReprs @k)) + +data QuantAsOneProof (f :: k -> Type) (tp :: QuantK k) where + QuantAsOneProof :: (IsExistsOr tp (OneK x), IfIsOneK tp (x ~ TheOneK tp)) => ReprOf x -> f x -> QuantAsOneProof f tp + +quantAsOne :: forall k f tp. HasReprK k => Quant (f :: k -> Type) (tp :: QuantK k) -> Maybe (QuantAsOneProof f tp) +quantAsOne q = case q of + QuantOne repr x-> withRepr repr $ Just (QuantAsOneProof repr x) + QuantExists q' -> case quantAsOne q' of + Just (QuantAsOneProof repr x) -> Just $ QuantAsOneProof repr x + Nothing -> Nothing + _ -> Nothing + +existsOrCases :: forall tp tp' a. IsExistsOr tp tp' => (tp ~ ExistsK => a) -> (tp ~ tp' => a) -> a +existsOrCases f g = case (isExistsOr :: ExistsOrCases tp tp') of + ExistsOrExists -> f + ExistsOrRefl -> g + +-- | Pattern for creating or matching a singleton 'Quant', generalized over the existential cases +pattern Single :: forall {k} f tp. (HasReprK k) => forall x. (IsExistsOr tp (OneK x), IfIsOneK tp (x ~ TheOneK tp)) => ReprOf x -> f x -> Quant (f :: k -> Type) tp +pattern Single repr x <- (quantAsOne -> Just (QuantAsOneProof repr x)) + where + Single (repr :: ReprOf x) x = existsOrCases @tp @(OneK x) (QuantExists (Single repr x)) (QuantOne repr x) + + +{-# COMPLETE Single, All #-} +{-# COMPLETE Single, QuantAll, QuantAny #-} +{-# COMPLETE Single, QuantAll, QuantSome #-} + +{-# COMPLETE All, QuantOne, QuantExists #-} +{-# COMPLETE All, QuantOne, QuantSome #-} + +newtype AsSingle (f :: QuantK k -> Type) (y :: k) where + AsSingle :: f (OneK y) -> AsSingle f y + +deriving instance Eq (f (OneK x)) => Eq ((AsSingle f) x) +deriving instance Ord (f (OneK x)) => Ord ((AsSingle f) x) +deriving instance Show (f (OneK x)) => Show ((AsSingle f) x) + +instance TestEquality f => TestEquality (AsSingle f) where + testEquality (AsSingle x) (AsSingle y) = case testEquality x y of + Just Refl -> Just Refl + Nothing -> Nothing + +instance OrdF f => OrdF (AsSingle f) where + compareF (AsSingle x) (AsSingle y) = case compareF x y of + EQF -> EQF + LTF -> LTF + GTF -> GTF + +instance forall f. ShowF f => ShowF (AsSingle f) where + showF (AsSingle x) = showF x + withShow _ (_ :: q tp) f = withShow (Proxy :: Proxy f) (Proxy :: Proxy (OneK tp)) f + +type QuantEach (f :: QuantK k -> Type) = Quant (AsSingle f) AllK + +viewQuantEach :: HasReprK k => QuantEach f -> (forall (x :: k). ReprOf x -> f (OneK x)) +viewQuantEach (QuantAll f) = \r -> case TMF.apply f r of AsSingle x -> x + +viewQuantEach' :: HasReprK k => Quant (AsSingle f) tp -> Maybe (Dict (IsExistsOr tp AllK), forall (x :: k). ReprOf x -> f (OneK x)) +viewQuantEach' q = case q of + QuantOne{} -> Nothing + QuantAll f -> Just (Dict, \r -> case TMF.apply f r of AsSingle x -> x) + QuantSome q' -> case viewQuantEach' q' of + Just (Dict, g) -> Just (Dict, g) + Nothing -> Nothing + +pattern QuantEach :: forall {k} f tp. (HasReprK k) => (IsExistsOr tp AllK) => (forall (x :: k). ReprOf x -> f (OneK x)) -> Quant (AsSingle f) tp +pattern QuantEach f <- (viewQuantEach' -> Just (Dict, f)) + where + QuantEach f = existsOrCases @tp @AllK (QuantAny (QuantEach f)) (QuantAll (TMF.mapWithKey (\r _ -> AsSingle (f r)) (allReprs @k))) + +{-# COMPLETE QuantEach, Single #-} + +_testQuantEach :: forall {k} f tp. HasReprK k => Quant (AsSingle (f :: QuantK k -> Type)) tp -> () +_testQuantEach = \case + QuantEach (_f :: forall (x :: k). ReprOf x -> f (OneK x)) -> () + Single (_repr :: ReprOf (x :: k)) (AsSingle (_x :: f (OneK x))) -> () + +_testQuantEach1 :: HasReprK k => Quant (AsSingle (f :: QuantK k -> Type)) AllK -> () +_testQuantEach1 = \case + QuantEach (_f :: forall (x :: k). ReprOf x -> f (OneK x)) -> () + -- complete match, since Single has an unsolvable constraint \ No newline at end of file diff --git a/src/Pate/Binary.hs b/src/Pate/Binary.hs index 1f03307c..a958e9fc 100644 --- a/src/Pate/Binary.hs +++ b/src/Pate/Binary.hs @@ -19,11 +19,13 @@ {-# LANGUAGE TypeSynonymInstances #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE OverloadedStrings #-} {-# OPTIONS_GHC -fno-warn-orphans #-} module Pate.Binary ( type WhichBinary + , type BinaryPair , KnownBinary , Original , Patched @@ -33,14 +35,29 @@ module Pate.Binary , flipRepr , short , otherInvolutive + , ppBinaryPair + , ppBinaryPairEq + , ppBinaryPair' + , w4SerializePair ) where +import Data.Kind ( Type ) +import Control.Applicative.Alternative ( (<|>) ) + import Data.Parameterized.WithRepr import Data.Parameterized.Classes import Data.Parameterized.Some +import qualified Data.Parameterized.TotalMapF as TMF +import qualified Data.Aeson as JSON + +import qualified Data.Quant as Qu +import Data.Quant ( Quant, QuantK) import qualified Prettyprinter as PP import Pate.TraceTree +import qualified What4.JSON as W4S +import What4.JSON ( (.:) ) +import qualified Pate.ExprMappable as PEM -- | A type-level tag describing whether the data value is from an original binary or a patched binary data WhichBinary = Original | Patched deriving (Bounded, Enum, Eq, Ord, Read, Show) @@ -119,3 +136,96 @@ instance KnownRepr WhichBinaryRepr Patched where type KnownBinary (bin :: WhichBinary) = KnownRepr WhichBinaryRepr bin instance IsRepr WhichBinaryRepr + +instance TMF.HasTotalMapF WhichBinaryRepr where + allValues = [Some OriginalRepr, Some PatchedRepr] + +instance Qu.HasReprK WhichBinary where + type ReprOf = WhichBinaryRepr + +type BinaryPair f = Qu.Quant (f :: WhichBinary -> Type) + +jsonQuant :: + (forall bin. f bin -> JSON.Value) -> BinaryPair f (qbin :: QuantK WhichBinary) -> JSON.Value +jsonQuant f q = case q of + Qu.All g -> JSON.object [ "original" JSON..= f (g OriginalRepr), "patched" JSON..= f (g PatchedRepr)] + Qu.Single bin x -> case bin of + OriginalRepr -> JSON.object [ "original" JSON..= f x ] + PatchedRepr -> JSON.object [ "patched" JSON..= f x ] + + +instance (forall bin. JSON.ToJSON (f bin)) => JSON.ToJSON (Qu.Quant f (qbin :: QuantK WhichBinary)) where + toJSON p = jsonQuant (JSON.toJSON) p + + +w4SerializePair :: BinaryPair f qbin -> (forall bin. f bin -> W4S.W4S sym JSON.Value) -> W4S.W4S sym JSON.Value +w4SerializePair bpair f = case bpair of + Qu.All g -> do + o_v <- f (g OriginalRepr) + p_v <- f (g PatchedRepr) + return $ JSON.object ["original" JSON..= o_v, "patched" JSON..= p_v] + Qu.Single bin x -> case bin of + OriginalRepr -> do + o_v <- f x + return $ JSON.object ["original" JSON..= o_v] + PatchedRepr -> do + p_v <- f x + return $ JSON.object ["patched" JSON..= p_v] + +instance W4S.W4SerializableF sym f => W4S.W4Serializable sym (Qu.Quant f (qbin :: QuantK WhichBinary)) where + w4Serialize ppair = w4SerializePair ppair W4S.w4SerializeF + + +instance forall f sym. (forall bin. KnownRepr WhichBinaryRepr bin => W4S.W4Deserializable sym (f bin)) => W4S.W4Deserializable sym (Quant f Qu.ExistsK) where + w4Deserialize_ v = do + JSON.Object o <- return v + let + case_pair = do + (vo :: f Original) <- o .: "original" + (vp :: f Patched) <- o .: "patched" + return $ Qu.All $ \case OriginalRepr -> vo; PatchedRepr -> vp + case_orig = do + (vo :: f Original) <- o .: "original" + return $ Qu.Single OriginalRepr vo + case_patched = do + (vp :: f Patched) <- o .: "patched" + return $ Qu.Single PatchedRepr vp + case_pair <|> case_orig <|> case_patched + +ppBinaryPair :: (forall bin. tp bin -> PP.Doc a) -> BinaryPair tp qbin -> PP.Doc a +ppBinaryPair f (Qu.All g) = f (g OriginalRepr) PP.<+> "(original) vs." PP.<+> f (g PatchedRepr) PP.<+> "(patched)" +ppBinaryPair f (Qu.Single OriginalRepr a1) = f a1 PP.<+> "(original)" +ppBinaryPair f (Qu.Single PatchedRepr a1) = f a1 PP.<+> "(patched)" + +ppBinaryPair' :: (forall bin. tp bin -> PP.Doc a) -> BinaryPair tp qbin -> PP.Doc a +ppBinaryPair' f pPair = ppBinaryPairEq (\x y -> show (f x) == show (f y)) f pPair + +-- | True if the two given values would be printed identically +ppEq :: PP.Pretty x => PP.Pretty y => x -> y -> Bool +ppEq x y = show (PP.pretty x) == show (PP.pretty y) + +instance ShowF tp => Show (Quant tp (qbin :: QuantK WhichBinary)) where + show (Qu.All g) = + let + s1 = showF (g OriginalRepr) + s2 = showF (g PatchedRepr) + in if s1 == s2 then s1 else s1 ++ " vs. " ++ s2 + show (Qu.Single OriginalRepr a1) = showF a1 ++ " (original)" + show (Qu.Single PatchedRepr a1) = showF a1 ++ " (patched)" + + +ppBinaryPairEq :: + (tp Original -> tp Patched -> Bool) -> + (forall bin. tp bin -> PP.Doc a) -> + BinaryPair tp qbin -> + PP.Doc a +ppBinaryPairEq test f (Qu.All g) = case test (g OriginalRepr) (g PatchedRepr) of + True -> f $ g OriginalRepr + False -> f (g OriginalRepr) PP.<+> "(original) vs." PP.<+> f (g PatchedRepr) PP.<+> "(patched)" +ppBinaryPairEq _ f pPair = ppBinaryPair f pPair + +instance (forall bin. PP.Pretty (f bin)) => PP.Pretty (Quant f (qbin :: QuantK WhichBinary)) where + pretty = ppBinaryPairEq ppEq PP.pretty + +instance (Qu.HasReprK k, forall (bin :: k). PEM.ExprMappable sym (f bin)) => PEM.ExprMappable sym (Quant f qbin) where + mapExpr sym f pp = Qu.traverse (PEM.mapExpr sym f) pp \ No newline at end of file diff --git a/src/Pate/Equivalence.hs b/src/Pate/Equivalence.hs index fff7dcd7..3f97dbc9 100644 --- a/src/Pate/Equivalence.hs +++ b/src/Pate/Equivalence.hs @@ -526,8 +526,8 @@ eqDomPre :: eqDomPre sym stO stP (eqCtxHDR -> hdr) eqDom = do let st = PPa.PatchPair stO stP - maxRegion = TF.fmapF (\st' -> Const $ unSE $ simMaxRegion st') st - stackBase = TF.fmapF (\st' -> Const $ unSE $ simStackBase st') st + maxRegion = PPa.map (\st' -> Const $ unSE $ simMaxRegion st') st + stackBase = PPa.map (\st' -> Const $ unSE $ simStackBase st') st regsEq <- regDomRel hdr sym stO stP (PED.eqDomainRegisters eqDom) maxRegionsEq <- mkNamedAsm sym (PED.eqDomainMaxRegion eqDom) (bindExprPair maxRegion) @@ -563,7 +563,7 @@ eqDomPost sym stO stP eqCtx domPre domPost = do st = PPa.PatchPair stO stP hdr = eqCtxHDR eqCtx stackRegion = eqCtxStackRegion eqCtx - maxRegion = TF.fmapF (\st' -> Const $ unSE $ simMaxRegion st') st + maxRegion = PPa.map (\st' -> Const $ unSE $ simMaxRegion st') st regsEq <- regDomRel hdr sym stO stP (PED.eqDomainRegisters domPost) stacksEq <- memDomPost sym (MemEqAtRegion stackRegion) stO stP (PED.eqDomainStackMemory domPre) (PED.eqDomainStackMemory domPost) diff --git a/src/Pate/Interactive/Render/Proof.hs b/src/Pate/Interactive/Render/Proof.hs index f78a6d4c..e6fe05ac 100644 --- a/src/Pate/Interactive/Render/Proof.hs +++ b/src/Pate/Interactive/Render/Proof.hs @@ -26,6 +26,7 @@ import qualified Data.Parameterized.Map as MapF import qualified Data.Parameterized.NatRepr as PN import Data.Parameterized.Some ( Some(..) ) import qualified Data.Parameterized.TraversableF as TF +import qualified Data.Parameterized.TraversableFC as TFC import Data.Proxy (Proxy(..)) import qualified Data.Text as T import qualified Data.Vector as DV @@ -286,7 +287,7 @@ renderRegVal domain reg regOp = PPa.PatchPairSingle{} -> Just prettySlotVal _ -> Just prettySlotVal where - vals = TF.fmapF (\(Const x) -> Const $ PFI.groundMacawValue x) $ PPr.slRegOpValues regOp + vals = TFC.fmapFC (\(Const x) -> Const $ PFI.groundMacawValue x) $ PPr.slRegOpValues regOp ppDom = case PFI.regInGroundDomain (PED.eqDomainRegisters domain) reg of @@ -334,7 +335,7 @@ renderMemCellVal -> Maybe (PP.Doc a) renderMemCellVal domain cell memOp = do guard (PG.groundValue $ PPr.slMemOpCond memOp) - let vals = TF.fmapF (\(Const x) -> Const $ PFI.groundBV x) $ PPr.slMemOpValues memOp + let vals = PPa.map (\(Const x) -> Const $ PFI.groundBV x) $ PPr.slMemOpValues memOp let ppDom = case PFI.cellInGroundDomain domain cell of True -> PP.emptyDoc False -> PP.pretty "| Excluded" @@ -365,7 +366,7 @@ renderIPs st | (PG.groundValue $ PPr.slRegOpEquiv pcRegs) = PP.pretty (PPa.someC vals) | otherwise = PPa.ppPatchPairC PP.pretty vals where - vals = TF.fmapF (\(Const x) -> Const $ PFI.groundMacawValue x) (PPr.slRegOpValues pcRegs) + vals = PPa.map (\(Const x) -> Const $ PFI.groundMacawValue x) (PPr.slRegOpValues pcRegs) pcRegs = PPr.slRegState st ^. MC.curIP renderReturn @@ -384,7 +385,7 @@ renderCounterexample -> TP.UI TP.Element renderCounterexample ineqRes' = PPr.withIneqResult ineqRes' $ \ineqRes -> let - groundEnd = TF.fmapF (\(Const x) -> Const $ (PFI.groundBlockEnd (Proxy @arch)) x) $ PPr.slBlockExitCase (PPr.ineqSlice ineqRes) + groundEnd = PPa.map (\(Const x) -> Const $ (PFI.groundBlockEnd (Proxy @arch)) x) $ PPr.slBlockExitCase (PPr.ineqSlice ineqRes) renderInequalityReason rsn = case rsn of PEE.InequivalentRegisters -> @@ -397,11 +398,11 @@ renderCounterexample ineqRes' = PPr.withIneqResult ineqRes' $ \ineqRes -> TP.string "The original and patched programs have generated invalid post states" renderedContinuation = TP.column (catMaybes [ Just (text (PP.pretty "Next IP: " <> renderIPs (PPr.slBlockPostState (PPr.ineqSlice ineqRes)))) - , renderReturn (TF.fmapF (\(Const x) -> Const $ PFI.grndBlockReturn x) groundEnd) + , renderReturn (PPa.map (\(Const x) -> Const $ PFI.grndBlockReturn x) groundEnd) ]) in TP.grid [ [ renderInequalityReason (PPr.ineqReason ineqRes) ] - , [ text (PPa.ppPatchPairCEq (PP.pretty . PFI.ppExitCase) (TF.fmapF (\(Const x) -> Const $ PFI.grndBlockCase x) groundEnd)) ] + , [ text (PPa.ppPatchPairCEq (PP.pretty . PFI.ppExitCase) (PPa.map (\(Const x) -> Const $ PFI.grndBlockCase x) groundEnd)) ] , [ TP.h2 #+ [TP.string "Initial states"] ] , [ renderRegisterState (PPr.ineqPre ineqRes) (PPr.slRegState (PPr.slBlockPreState (PPr.ineqSlice ineqRes))) , renderMemoryState (PPr.ineqPre ineqRes) (PPr.slMemState (PPr.slBlockPreState (PPr.ineqSlice ineqRes))) diff --git a/src/Pate/Location.hs b/src/Pate/Location.hs index 9211fd47..7709d109 100644 --- a/src/Pate/Location.hs +++ b/src/Pate/Location.hs @@ -406,6 +406,6 @@ instance (LocationTraversable sym arch a, LocationTraversable sym arch b) => instance (forall bin. (LocationWitherable sym arch (f bin))) => LocationWitherable sym arch (PPa.PatchPair f) where - witherLocation sym pp f = TF.traverseF (\x -> witherLocation sym x f) pp + witherLocation sym pp f = PPa.traverse (\x -> witherLocation sym x f) pp diff --git a/src/Pate/Monad.hs b/src/Pate/Monad.hs index a221f9db..726569b3 100644 --- a/src/Pate/Monad.hs +++ b/src/Pate/Monad.hs @@ -1531,7 +1531,7 @@ withPair :: PB.BlockPair arch -> EquivM sym arch a -> EquivM sym arch a withPair pPair f = do env <- CMR.ask let env' = env { envParentBlocks = pPair:envParentBlocks env } - let entryPair = TF.fmapF (\b -> PB.functionEntryToConcreteBlock (PB.blockFunctionEntry b)) pPair + let entryPair = PPa.map (\b -> PB.functionEntryToConcreteBlock (PB.blockFunctionEntry b)) pPair CMR.local (\_ -> env' & PME.envCtxL . PMC.currentFunc .~ entryPair) f -- | Emit a trace event to the frontend @@ -1543,7 +1543,7 @@ traceBlockPair -> String -> EquivM sym arch () traceBlockPair bp msg = - emitEvent (PE.ProofTraceEvent callStack (TF.fmapF (Const . PB.concreteAddress) bp) (T.pack msg)) + emitEvent (PE.ProofTraceEvent callStack (PPa.map (Const . PB.concreteAddress) bp) (T.pack msg)) -- | Emit a trace event to the frontend -- @@ -1554,7 +1554,7 @@ traceBundle -> String -> EquivM sym arch () traceBundle bundle msg = do - let bp = TF.fmapF (Const . PB.concreteAddress . simInBlock) (simIn bundle) + let bp = PPa.map (Const . PB.concreteAddress . simInBlock) (simIn bundle) emitEvent (PE.ProofTraceEvent callStack bp (T.pack msg)) fnTrace :: diff --git a/src/Pate/Monad/PairGraph.hs b/src/Pate/Monad/PairGraph.hs index eb5e0c04..938182dc 100644 --- a/src/Pate/Monad/PairGraph.hs +++ b/src/Pate/Monad/PairGraph.hs @@ -167,7 +167,7 @@ initializePairGraph pPairs = foldM (\x y -> initPair x y) emptyPairGraph pPairs where initPair :: PairGraph sym arch -> PB.FunPair arch -> EquivM sym arch (PairGraph sym arch) initPair gr fnPair = - do let bPair = TF.fmapF PB.functionEntryToConcreteBlock fnPair + do let bPair = PPa.map PB.functionEntryToConcreteBlock fnPair withPair bPair $ do -- initial state of the pair graph: choose the universal domain that equates as much as possible let node = GraphNode (rootEntry bPair) @@ -205,7 +205,7 @@ initialDomainSpec (GraphNodeEntry blocks) = withTracing @"function_name" "initia dom <- initialDomain return (mempty, dom) initialDomainSpec (GraphNodeReturn fPair) = withTracing @"function_name" "initialDomainSpec" $ do - let blocks = TF.fmapF PB.functionEntryToConcreteBlock fPair + let blocks = PPa.map PB.functionEntryToConcreteBlock fPair withFreshVars blocks $ \_vars -> do dom <- initialDomain return (mempty, dom) diff --git a/src/Pate/PatchPair.hs b/src/Pate/PatchPair.hs index 79a4ad7f..1181b76f 100644 --- a/src/Pate/PatchPair.hs +++ b/src/Pate/PatchPair.hs @@ -17,9 +17,12 @@ {-# LANGUAGE FunctionalDependencies #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ViewPatterns #-} module Pate.PatchPair ( PatchPair + , traverse + , map , pattern PatchPair , pattern PatchPairSingle , pattern PatchPairOriginal @@ -70,12 +73,10 @@ module Pate.PatchPair ( , asSingleton , toSingleton , zip - , jsonPatchPair - , w4SerializePair , WithBin(..) ) where -import Prelude hiding (zip) +import Prelude hiding (zip, map, traverse) import GHC.Stack (HasCallStack) import Control.Monad.Trans (lift) import Control.Monad.Trans.Maybe @@ -88,6 +89,9 @@ import Data.Functor.Const ( Const(..) ) import qualified Data.Kind as DK import Data.Parameterized.Classes import qualified Data.Parameterized.TraversableF as TF +import qualified Data.Parameterized.TraversableFC as TFC +import qualified Data.Quant as Qu + import qualified Prettyprinter as PP import qualified Data.Aeson as JSON import qualified Compat.Aeson as JSON @@ -110,14 +114,25 @@ import Control.Applicative ( (<|>) ) -- "full" (i.e. containing values for both binaries). Singleton 'PatchPair' values are used -- to handle cases where the control flow between the binaries has diverged and the verifier -- needs to handle each one independently. -data PatchPair (tp :: PB.WhichBinary -> DK.Type) = PatchPairCtor - { _pOriginal :: tp PB.Original - , _pPatched :: tp PB.Patched - } - | forall bin. PatchPairSingle (PB.WhichBinaryRepr bin) (tp bin) +type PatchPair (tp :: PB.WhichBinary -> DK.Type) = Qu.Quant tp Qu.ExistsK + +map :: (forall (bin :: PB.WhichBinary). f bin -> g bin) -> PatchPair f -> PatchPair g +map = TFC.fmapFC + +traverse :: Applicative m => (forall (bin :: PB.WhichBinary). f bin -> m (g bin)) -> PatchPair f -> m (PatchPair g) +traverse = TFC.traverseFC pattern PatchPair :: (tp PB.Original) -> (tp PB.Patched) -> PatchPair tp -pattern PatchPair a b = PatchPairCtor a b +pattern PatchPair a b <- ((\l -> case l of Qu.All f -> Just (f PB.OriginalRepr, f PB.PatchedRepr); _ -> Nothing) -> Just (a, b)) + where + PatchPair a b = Qu.QuantSome $ Qu.generateAll $ \case + PB.OriginalRepr -> a + PB.PatchedRepr -> b + +pattern PatchPairSingle :: () => forall x. PB.WhichBinaryRepr x -> tp x -> PatchPair tp +pattern PatchPairSingle repr x <- (Qu.Single repr x) + where + PatchPairSingle repr x = Qu.QuantSome $ Qu.QuantOne repr x pattern PatchPairOriginal :: tp PB.Original -> PatchPair tp pattern PatchPairOriginal a = PatchPairSingle PB.OriginalRepr a @@ -316,8 +331,8 @@ mkSingle bin a = PatchPairSingle bin a -- | Return the single 'tp' and which binary if the input is a singleton 'PatchPair'. -- 'asSingleton (toSingleton bin x) == (bin, x)' when 'x' contains an entry for 'bin' -- '(y,bin) <- asSingleton x; toSingleton bin y == x' when 'x' is a singleton -asSingleton :: PatchPairM m => PatchPair tp -> m (Pair PB.WhichBinaryRepr tp) -asSingleton (PatchPairSingle bin v) = return (Pair bin v) +asSingleton :: PatchPairM m => Qu.Quant tp qbin -> m (Pair PB.WhichBinaryRepr tp) +asSingleton (Qu.Single bin v) = return (Pair bin v) asSingleton _ = throwPairErr -- | Convert a 'PatchPair' into a singleton containing only @@ -383,7 +398,7 @@ insertWith bin v f = \case type PatchPairC tp = PatchPair (Const tp) pattern PatchPairC :: tp -> tp -> PatchPair (Const tp) -pattern PatchPairC a b = PatchPairCtor (Const a) (Const b) +pattern PatchPairC a b = PatchPair (Const a) (Const b) {-# COMPLETE PatchPairC, PatchPairSingle #-} {-# COMPLETE PatchPairC, PatchPairOriginal, PatchPairPatched #-} @@ -404,7 +419,7 @@ instance (forall tp. Show (t (f tp))) => ShowF (LiftF t f) type PatchPairF t tp = PatchPair (LiftF t tp) pattern PatchPairF :: t (tp PB.Original) -> t (tp PB.Patched) -> PatchPair (LiftF t tp) -pattern PatchPairF a b = PatchPairCtor (LiftF a) (LiftF b) +pattern PatchPairF a b = PatchPair (LiftF a) (LiftF b) {-# COMPLETE PatchPairF, PatchPairSingle #-} {-# COMPLETE PatchPairF, PatchPairOriginal, PatchPairPatched #-} @@ -512,56 +527,11 @@ forBins2 f = fmap unzipPatchPair2 $ forBins $ \bin -> do ppEq :: PP.Pretty x => PP.Pretty y => x -> y -> Bool ppEq x y = show (PP.pretty x) == show (PP.pretty y) -instance TestEquality tp => Eq (PatchPair tp) where - PatchPair o1 p1 == PatchPair o2 p2 - | Just Refl <- testEquality o1 o2 - , Just Refl <- testEquality p1 p2 - = True - PatchPairSingle _ a1 == PatchPairSingle _ a2 | Just Refl <- testEquality a1 a2 = True - _ == _ = False - -instance forall tp. (TestEquality tp, OrdF tp) => Ord (PatchPair tp) where - compare pp1 pp2 = case (pp1,pp2) of - (PatchPair o1 p1, PatchPair o2 p2) -> toOrdering $ (lexCompareF o1 o2 (compareF p1 p2)) - (PatchPairSingle _ s1, PatchPairSingle _ s2) -> toOrdering $ compareF s1 s2 - (PatchPairSingle{},PatchPair{}) -> LT - (PatchPair{},PatchPairSingle{}) -> GT - -instance TF.FunctorF PatchPair where - fmapF = TF.fmapFDefault - -instance TF.FoldableF PatchPair where - foldMapF = TF.foldMapFDefault - -instance (forall bin. PEM.ExprMappable sym (f bin)) => PEM.ExprMappable sym (PatchPair f) where - mapExpr sym f pp = TF.traverseF (PEM.mapExpr sym f) pp - -instance TF.TraversableF PatchPair where - traverseF f pp = case pp of - (PatchPair o p) -> PatchPair <$> f o <*> f p - (PatchPairSingle bin s) -> PatchPairSingle bin <$> f s - - -instance ShowF tp => Show (PatchPair tp) where - show (PatchPair a1 a2) = - let - s1 = showF a1 - s2 = showF a2 - in if s1 == s2 then s1 else s1 ++ " vs. " ++ s2 - show (PatchPairOriginal a1) = showF a1 ++ " (original)" - show (PatchPairPatched a1) = showF a1 ++ " (patched)" - -instance (forall bin. PP.Pretty (f bin)) => PP.Pretty (PatchPair f) where - pretty = ppPatchPairEq ppEq PP.pretty - ppPatchPair' :: (forall bin. tp bin -> PP.Doc a) -> PatchPair tp -> PP.Doc a ppPatchPair' f pPair = ppPatchPairEq (\x y -> show (f x) == show (f y)) f pPair - ppPatchPair :: (forall bin. tp bin -> PP.Doc a) -> PatchPair tp -> PP.Doc a -ppPatchPair f (PatchPair a1 a2) = f a1 PP.<+> "(original) vs." PP.<+> f a2 PP.<+> "(patched)" -ppPatchPair f (PatchPairOriginal a1) = f a1 PP.<+> "(original)" -ppPatchPair f (PatchPairPatched a1) = f a1 PP.<+> "(patched)" +ppPatchPair = PB.ppBinaryPair ppPatchPairEq :: (tp PB.Original -> tp PB.Patched -> Bool) -> @@ -590,49 +560,4 @@ ppPatchPairCEq f ppair@(PatchPair (Const o) (Const p)) = case o == p of True -> f o False -> ppPatchPairC f ppair ppPatchPairCEq f (PatchPairOriginal (Const a)) = f a PP.<+> "(original)" -ppPatchPairCEq f (PatchPairPatched (Const a)) = f a PP.<+> "(patched)" - - -jsonPatchPair :: - (forall bin. tp bin -> JSON.Value) -> PatchPair tp -> JSON.Value -jsonPatchPair f ppair = case ppair of - PatchPair o p -> JSON.object [ "original" JSON..= (f o), "patched" JSON..= (f p)] - PatchPairOriginal o -> JSON.object [ "original" JSON..= (f o) ] - PatchPairPatched p -> JSON.object [ "patched" JSON..= (f p) ] - - -instance (forall bin. JSON.ToJSON (tp bin)) => JSON.ToJSON (PatchPair tp) where - toJSON p = jsonPatchPair (JSON.toJSON) p - -w4SerializePair :: PatchPair f -> (forall bin. f bin -> W4S.W4S sym JSON.Value) -> W4S.W4S sym JSON.Value -w4SerializePair ppair f = case ppair of - PatchPair o p -> do - o_v <- f o - p_v <- f p - return $ JSON.object ["original" JSON..= o_v, "patched" JSON..= p_v] - PatchPairOriginal o -> do - o_v <- f o - return $ JSON.object ["original" JSON..= o_v] - PatchPairPatched p -> do - p_v <- f p - return $ JSON.object ["patched" JSON..= p_v] - -instance W4S.W4SerializableF sym f => W4S.W4Serializable sym (PatchPair f) where - w4Serialize ppair = w4SerializePair ppair w4SerializeF - - -instance (forall bin. PB.KnownBinary bin => W4Deserializable sym (f bin)) => W4Deserializable sym (PatchPair f) where - w4Deserialize_ v = do - JSON.Object o <- return v - let - case_pair = do - (vo :: f PB.Original) <- o .: "original" - (vp :: f PB.Patched) <- o .: "patched" - return $ PatchPair vo vp - case_orig = do - (vo :: f PB.Original) <- o .: "original" - return $ PatchPairOriginal vo - case_patched = do - (vp :: f PB.Patched) <- o .: "patched" - return $ PatchPairPatched vp - case_pair <|> case_orig <|> case_patched \ No newline at end of file +ppPatchPairCEq f (PatchPairPatched (Const a)) = f a PP.<+> "(patched)" \ No newline at end of file diff --git a/src/Pate/Proof.hs b/src/Pate/Proof.hs index 98d8ee65..bc784ec7 100644 --- a/src/Pate/Proof.hs +++ b/src/Pate/Proof.hs @@ -429,7 +429,7 @@ data BlockSliceRegOp sym tp where instance PEM.ExprMappable sym (BlockSliceRegOp sym tp) where mapExpr sym f regOp = BlockSliceRegOp - <$> TF.traverseF (PEM.mapExpr sym f) (slRegOpValues regOp) + <$> PPa.traverse (PEM.mapExpr sym f) (slRegOpValues regOp) <*> pure (slRegOpRepr regOp) <*> f (slRegOpEquiv regOp) @@ -448,7 +448,7 @@ data BlockSliceMemOp sym w where instance PEM.ExprMappable sym (BlockSliceMemOp sym w) where mapExpr sym f memOp = BlockSliceMemOp - <$> TF.traverseF (\(Const x) -> Const <$> W4H.mapExprPtr sym f x) (slMemOpValues memOp) + <$> PPa.traverse (\(Const x) -> Const <$> W4H.mapExprPtr sym f x) (slMemOpValues memOp) <*> f (slMemOpEquiv memOp) <*> f (slMemOpCond memOp) diff --git a/src/Pate/Proof/Instances.hs b/src/Pate/Proof/Instances.hs index ec460603..9b53d134 100644 --- a/src/Pate/Proof/Instances.hs +++ b/src/Pate/Proof/Instances.hs @@ -408,7 +408,7 @@ ppBlockSliceTransition :: PF.BlockSliceTransition grnd arch -> PP.Doc a ppBlockSliceTransition pre post bs = PP.vsep $ - [ "Block Exit Condition:" <+> PPa.ppPatchPairCEq (PP.pretty . ppExitCase) (TF.fmapF (\(Const x) -> Const $ grndBlockCase x) groundEnd) + [ "Block Exit Condition:" <+> PPa.ppPatchPairCEq (PP.pretty . ppExitCase) (PPa.map (\(Const x) -> Const $ grndBlockCase x) groundEnd) , "Initial register state:" , ppRegs pre (PF.slRegState $ PF.slBlockPreState bs) , "Initial memory state:" @@ -418,13 +418,13 @@ ppBlockSliceTransition pre post bs = PP.vsep $ , "Final memory state:" , ppMemCellMap post (PF.slMemState $ PF.slBlockPostState bs) , "Final IP:" <+> ppIPs (PF.slBlockPostState bs) - , case TF.fmapF (\(Const x) -> Const $ grndBlockReturn x) groundEnd of + , case PPa.map (\(Const x) -> Const $ grndBlockReturn x) groundEnd of PPa.PatchPairC (Just cont1) (Just cont2) -> "Function Continue Address:" <+> PPa.ppPatchPairCEq (PP.pretty . ppLLVMPointer) (PPa.PatchPairC cont1 cont2) _ -> PP.emptyDoc ] where - groundEnd = TF.fmapF (\(Const x) -> Const $ groundBlockEnd (Proxy @arch) x) $ PF.slBlockExitCase bs + groundEnd = PPa.map (\(Const x) -> Const $ groundBlockEnd (Proxy @arch) x) $ PF.slBlockExitCase bs ppIPs :: PA.ValidArch arch => @@ -434,7 +434,7 @@ ppIPs :: ppIPs st = let pcRegs = (PF.slRegState st) ^. MM.curIP - vals = TF.fmapF (\(Const x) -> Const $ groundMacawValue x) (PF.slRegOpValues pcRegs) + vals = PPa.map (\(Const x) -> Const $ groundMacawValue x) (PF.slRegOpValues pcRegs) in case PG.groundValue $ PF.slRegOpEquiv pcRegs of True -> PP.pretty $ PPa.someC vals False -> PPa.ppPatchPairC PP.pretty vals @@ -548,7 +548,7 @@ ppRegVal dom reg regOp = case PF.slRegOpRepr regOp of False -> Just $ ppSlotVal _ -> Just $ ppSlotVal where - vals = TF.fmapF (\(Const x) -> Const $ groundMacawValue x) $ PF.slRegOpValues regOp + vals = PPa.map (\(Const x) -> Const $ groundMacawValue x) $ PF.slRegOpValues regOp ppSlotVal = PP.pretty (showF reg) <> ":" <+> ppVals <+> ppDom ppDom = case regInGroundDomain dom reg of @@ -581,7 +581,7 @@ ppCellVal dom cell memOp = case PG.groundValue $ PF.slMemOpCond memOp of True -> Just $ ppSlotVal False -> Nothing where - vals = TF.fmapF (\(Const x) -> Const $ groundBV x) $ PF.slMemOpValues memOp + vals = PPa.map (\(Const x) -> Const $ groundBV x) $ PF.slMemOpValues memOp ppSlotVal = ppGroundCell cell <> ":" <+> ppVals <+> ppDom ppDom = case cellInGroundDomain dom cell of diff --git a/src/Pate/Proof/Operations.hs b/src/Pate/Proof/Operations.hs index 04c58606..7fdd6e8e 100644 --- a/src/Pate/Proof/Operations.hs +++ b/src/Pate/Proof/Operations.hs @@ -163,7 +163,7 @@ noTransition :: EquivM sym arch (PF.BlockSliceTransition sym arch) noTransition scope stIn blockEnd = do let - stOut = TF.fmapF (\st -> PS.SimOutput (PS.simInState st) blockEnd) stIn + stOut = PPa.map (\st -> PS.SimOutput (PS.simInState st) blockEnd) stIn bundle = PS.SimBundle stIn stOut simBundleToSlice scope bundle diff --git a/src/Pate/SimState.hs b/src/Pate/SimState.hs index 7725aece..96ce23ef 100644 --- a/src/Pate/SimState.hs +++ b/src/Pate/SimState.hs @@ -373,7 +373,7 @@ bundleInVars :: SimScope sym arch v -> SimBundle sym arch v -> (SimVars sym arch bundleInVars scope bundle = let (stO,stP) = asStatePair scope (simIn bundle) simInState in (SimVars stO, SimVars stP) simPair :: SimBundle sym arch v -> PB.BlockPair arch -simPair bundle = TF.fmapF simInBlock (simIn bundle) +simPair bundle = PPa.map simInBlock (simIn bundle) --------------------------------------- -- Variable binding diff --git a/src/Pate/Verification/PairGraph.hs b/src/Pate/Verification/PairGraph.hs index e8af885d..05d02a8c 100644 --- a/src/Pate/Verification/PairGraph.hs +++ b/src/Pate/Verification/PairGraph.hs @@ -157,6 +157,7 @@ import Data.Parameterized.Classes import Data.Set (Set) import qualified Data.Set as Set import Data.Word (Word32) +import qualified Data.Quant as Qu import qualified Lumberjack as LJ import Data.Parameterized (Some(..), Pair (..)) @@ -333,8 +334,10 @@ data WorkItem arch = (SingleNodeEntry arch PBi.Original) (SingleNodeEntry arch PBi.Patched) -- | Handle starting a split analysis from a diverging node. - | ProcessSplitCtor (Some (SingleNodeEntry arch)) - deriving (Eq, Ord) + | ProcessSplitCtor (Some (Qu.AsSingle (NodeEntry' arch))) + +deriving instance Eq (WorkItem arch) +deriving instance Ord (WorkItem arch) instance PA.ValidArch arch => Show (WorkItem arch) where show = \case @@ -362,8 +365,7 @@ pattern ProcessMerge:: SingleNodeEntry arch PBi.Original -> SingleNodeEntry arch pattern ProcessMerge sneO sneP <- (processMergeSinglePair -> Just (sneO, sneP)) pattern ProcessSplit :: SingleNodeEntry arch bin -> WorkItem arch -pattern ProcessSplit sne <- ProcessSplitCtor (Some sne) - +pattern ProcessSplit sne = ProcessSplitCtor (Some (Qu.AsSingle sne)) {-# COMPLETE ProcessNode, ProcessMergeAtExits, ProcessMergeAtEntry, ProcessSplit #-} {-# COMPLETE ProcessNode, ProcessMerge, ProcessSplit #-} @@ -436,7 +438,7 @@ data SyncData arch = -- | Defines exceptions for exits that would otherwise be considered sync points. -- In these cases, the single-sided analysis continues instead, with the intention -- that another sync point is encountered after additional instructions are executed - , _syncExceptions :: PPa.PatchPair (SetF (TupleF '(SingleNodeEntry arch, PB.BlockTarget arch))) + , _syncExceptions :: PPa.PatchPair (SetF (TupleF '(Qu.AsSingle (NodeEntry' arch), PB.BlockTarget arch))) -- Exits from the corresponding desync node that start the single-sided analysis , _syncDesyncExits :: PPa.PatchPair (SetF (PB.BlockTarget arch)) } @@ -451,7 +453,9 @@ syncPointBin :: SyncPoint arch bin -> PBi.WhichBinaryRepr bin syncPointBin sp = singleEntryBin $ syncPointNode sp instance TestEquality (SyncPoint arch) where - testEquality e1 e2 = testEquality (syncPointNode e1) (syncPointNode e2) + testEquality e1 e2 = case testEquality (syncPointNode e1) (syncPointNode e2) of + Just Refl -> Just Refl + Nothing -> Nothing instance OrdF (SyncPoint arch) where compareF sp1 sp2 = lexCompareF (syncPointBin sp1) (syncPointBin sp2) $ @@ -595,7 +599,7 @@ mkProcessSplit sne = do GraphNode dp_ne <- return $ singleNodeDivergePoint sne sne_dp <- toSingleNodeEntry (singleEntryBin sne) dp_ne guard (sne_dp == sne) - return (ProcessSplitCtor (Some sne)) + return (ProcessSplit sne) workItemNode :: WorkItem arch -> GraphNode arch @@ -1271,9 +1275,9 @@ pgMaybe msg Nothing = throwError $ PEE.PairGraphErr msg -- FIXME: do we need to support mismatched node kinds here? combineNodes :: SingleNodeEntry arch bin -> SingleNodeEntry arch (PBi.OtherBinary bin) -> Maybe (GraphNode arch) combineNodes node1 node2 = do - let ndPair = PPa.mkPair (singleEntryBin node1) node1 node2 - nodeO <- PPa.get PBi.OriginalRepr ndPair - nodeP <- PPa.get PBi.PatchedRepr ndPair + let ndPair = PPa.mkPair (singleEntryBin node1) (Qu.AsSingle node1) (Qu.AsSingle node2) + Qu.AsSingle nodeO <- PPa.get PBi.OriginalRepr ndPair + Qu.AsSingle nodeP <- PPa.get PBi.PatchedRepr ndPair -- it only makes sense to combine nodes that share a divergence point, -- where that divergence point will be used as the calling context for the -- merged point @@ -1458,7 +1462,7 @@ isSyncExit :: isSyncExit sne blkt@(PB.BlockTarget{}) = do excepts <- getSingleNodeData syncExceptions sne syncs <- getSingleNodeData syncPoints sne - let isExcept = Set.member (TupleF2 sne blkt) excepts + let isExcept = Set.member (TupleF2 (Qu.AsSingle sne) blkt) excepts case isExcept of True -> return Nothing False -> isCutAddressFor sne (PB.targetRawPC blkt) >>= \case @@ -1528,7 +1532,7 @@ filterSyncExits priority (ProcessSplit sne) blktPairs = pgValid $ do return x filterSyncExits priority (ProcessNode (GraphNode ne)) blktPairs = case asSingleNodeEntry ne of Nothing -> return blktPairs - Just (Some sne) -> do + Just (Some (Qu.AsSingle sne)) -> do let bin = singleEntryBin sne blkts <- mapM (PPa.get bin) blktPairs syncExits <- catMaybes <$> mapM (isSyncExit sne) blkts @@ -1543,14 +1547,14 @@ addReturnPointSync :: PPa.PatchPair (PB.BlockTarget arch) -> PairGraphM sym arch () addReturnPointSync priority ne blktPair = case asSingleNodeEntry ne of - Just (Some sne) -> do + Just (Some (Qu.AsSingle sne)) -> do let bin = singleEntryBin sne blkt <- PPa.get bin blktPair case PB.targetReturn blkt of Just ret -> do cuts <- getSingleNodeData syncCutAddresses sne excepts <- getSingleNodeData syncExceptions sne - let isExcept = Set.member (TupleF2 sne blkt) excepts + let isExcept = Set.member (TupleF2 (Qu.AsSingle sne) blkt) excepts case (not isExcept) && Set.member (PPa.WithBin (singleEntryBin sne) (PB.concreteAddress ret)) cuts of True -> do @@ -1600,7 +1604,7 @@ handleSingleSidedReturnTo :: NodeEntry arch -> PairGraphM sym arch () handleSingleSidedReturnTo priority ne = case asSingleNodeEntry ne of - Just (Some sne) -> do + Just (Some (Qu.AsSingle sne)) -> do let bin = singleEntryBin sne let dp = singleNodeDivergePoint sne syncAddrs <- getSyncData syncCutAddresses bin dp diff --git a/src/Pate/Verification/PairGraph/Node.hs b/src/Pate/Verification/PairGraph/Node.hs index 8562ab60..8665278a 100644 --- a/src/Pate/Verification/PairGraph/Node.hs +++ b/src/Pate/Verification/PairGraph/Node.hs @@ -16,9 +16,14 @@ {-# OPTIONS_GHC -fno-warn-orphans #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE FlexibleContexts #-} module Pate.Verification.PairGraph.Node ( - GraphNode(..) + GraphNode + , GraphNode'(..) + , NodeEntry' + , NodeReturn' , NodeEntry , NodeReturn , CallingContext @@ -61,12 +66,17 @@ module Pate.Verification.PairGraph.Node ( , singleNodeDivergence , toSingleNodeEntry , singleNodeAddr + , SingleNodeReturn + , SingleGraphNode + , pattern SingleNodeReturn ) where import Prettyprinter ( Pretty(..), sep, (<+>), Doc ) import qualified Data.Aeson as JSON import qualified Compat.Aeson as HMS -import qualified Data.Parameterized.TraversableF as TF + +import qualified Data.Quant as Qu +import Data.Quant ( Quant(..), QuantK, ExistsK ) import qualified Pate.Arch as PA import qualified Pate.Block as PB @@ -80,6 +90,7 @@ import Control.Monad (guard) import Data.Parameterized.Classes import Pate.Panic import qualified Pate.Address as PAd +import Data.Kind (Type) -- | Nodes in the program graph consist either of a pair of -- program points (GraphNode), or a synthetic node representing @@ -91,63 +102,87 @@ import qualified Pate.Address as PAd -- domain is propagated to all the potential return sites for -- that function, which are recorded separately in the -- "return vectors" map. -data GraphNode arch - = GraphNode (NodeEntry arch) - | ReturnNode (NodeReturn arch) +data GraphNode' arch (bin :: QuantK PB.WhichBinary) + = GraphNode (NodeEntry' arch bin) + | ReturnNode (NodeReturn' arch bin) deriving (Eq, Ord) -instance PA.ValidArch arch => JSON.ToJSON (GraphNode arch) where +type GraphNode arch = GraphNode' arch ExistsK + +instance PA.ValidArch arch => JSON.ToJSON (GraphNode' arch bin) where toJSON = \case GraphNode nd -> JSON.object [ ("graph_node_type", "entry"), "entry_body" JSON..= nd] ReturnNode nd -> JSON.object [ ("graph_node_type", "return"), "return_body" JSON..= nd] -instance PA.ValidArch arch => W4S.W4Serializable sym (GraphNode arch) where +instance PA.ValidArch arch => W4S.W4Serializable sym (GraphNode' arch bin) where w4Serialize r = return $ JSON.toJSON r -instance PA.ValidArch arch => W4S.W4Serializable sym (NodeEntry arch) where +instance PA.ValidArch arch => W4S.W4Serializable sym (NodeEntry' arch bin) where w4Serialize r = return $ JSON.toJSON r -data NodeContent arch e = - NodeContent { nodeContentCtx :: CallingContext arch, nodeContent :: e } - deriving (Eq, Ord) +data NodeContent arch (f :: PB.WhichBinary -> Type) (qbin :: QuantK PB.WhichBinary) = + NodeContent { nodeContentCtx :: CallingContext arch, nodeContent :: Quant f qbin } + +deriving instance (forall x. Eq (f x)) => Eq (NodeContent arch f qbin) +deriving instance (forall x. Ord (f x)) => Ord (NodeContent arch f qbin) + +instance (forall x. Eq (f x)) => TestEquality (NodeContent arch f) where + testEquality (NodeContent cctx1 x1) (NodeContent cctx2 x2) | cctx1 == cctx2, Just Refl <- testEquality x1 x2 = Just Refl + testEquality _ _ = Nothing -type NodeEntry arch = NodeContent arch (PB.BlockPair arch) +instance (forall x. Ord (f x)) => OrdF (NodeContent arch f) where + compareF (NodeContent cctx1 x1) (NodeContent cctx2 x2) = lexCompareF x1 x2 $ fromOrdering (compare cctx1 cctx2) -pattern NodeEntry :: CallingContext arch -> PB.BlockPair arch -> NodeEntry arch +type NodeEntry' arch = NodeContent arch (PB.ConcreteBlock arch) +type NodeEntry arch = NodeEntry' arch ExistsK + +instance Qu.QuantCoercible (NodeEntry' arch) where + coerceQuant (NodeEntry cctx blks) = NodeEntry cctx (Qu.coerceQuant blks) + +pattern NodeEntry :: CallingContext arch -> Quant (PB.ConcreteBlock arch) bin -> NodeEntry' arch bin pattern NodeEntry ctx bp = NodeContent ctx bp {-# COMPLETE NodeEntry #-} +nodeEntryRepr :: NodeEntry' arch qbin -> Qu.QuantRepr qbin +nodeEntryRepr ne = Qu.quantToRepr $ nodeBlocks ne -nodeBlocks :: NodeEntry arch -> PB.BlockPair arch +nodeBlocks :: NodeEntry' arch bin -> Quant (PB.ConcreteBlock arch) bin nodeBlocks = nodeContent -graphNodeContext :: NodeEntry arch -> CallingContext arch +graphNodeContext :: NodeEntry' arch bin -> CallingContext arch graphNodeContext = nodeContentCtx -type NodeReturn arch = NodeContent arch (PB.FunPair arch) +type NodeReturn' arch = NodeContent arch (PB.FunctionEntry arch) +type NodeReturn arch = NodeReturn' arch ExistsK + +nodeReturnRepr :: NodeReturn' arch qbin -> Qu.QuantRepr qbin +nodeReturnRepr ne = Qu.quantToRepr $ nodeFuns ne -nodeFuns :: NodeReturn arch -> PB.FunPair arch +nodeFuns :: NodeReturn' arch bin -> Quant (PB.FunctionEntry arch) bin nodeFuns = nodeContent -returnNodeContext :: NodeReturn arch -> CallingContext arch +returnNodeContext :: NodeReturn' arch bin -> CallingContext arch returnNodeContext = nodeContentCtx -pattern NodeReturn :: CallingContext arch -> PB.FunPair arch -> NodeReturn arch +pattern NodeReturn :: CallingContext arch -> Quant (PB.FunctionEntry arch) bin -> NodeReturn' arch bin pattern NodeReturn ctx bp = NodeContent ctx bp {-# COMPLETE NodeReturn #-} -graphNodeBlocks :: GraphNode arch -> PB.BlockPair arch +instance Qu.QuantCoercible (NodeReturn' arch) where + coerceQuant (NodeReturn cctx fns) = NodeReturn cctx (Qu.coerceQuant fns) + +graphNodeBlocks :: GraphNode' arch bin -> Quant (PB.ConcreteBlock arch) bin graphNodeBlocks (GraphNode ne) = nodeBlocks ne -graphNodeBlocks (ReturnNode ret) = TF.fmapF PB.functionEntryToConcreteBlock (nodeFuns ret) +graphNodeBlocks (ReturnNode ret) = Qu.map PB.functionEntryToConcreteBlock (nodeFuns ret) nodeContext :: GraphNode arch -> CallingContext arch nodeContext (GraphNode nd) = nodeContentCtx nd nodeContext (ReturnNode ret) = nodeContentCtx ret -pattern GraphNodeEntry :: PB.BlockPair arch -> GraphNode arch +pattern GraphNodeEntry :: Quant (PB.ConcreteBlock arch) bin -> GraphNode' arch bin pattern GraphNodeEntry blks <- (GraphNode (NodeContent _ blks)) -pattern GraphNodeReturn :: PB.FunPair arch -> GraphNode arch +pattern GraphNodeReturn :: Quant (PB.FunctionEntry arch) bin -> GraphNode' arch bin pattern GraphNodeReturn blks <- (ReturnNode (NodeContent _ blks)) {-# COMPLETE GraphNodeEntry, GraphNodeReturn #-} @@ -176,25 +211,28 @@ getDivergePoint nd = case nd of GraphNode (NodeEntry ctx _) -> divergePoint ctx ReturnNode (NodeReturn ctx _) -> divergePoint ctx -rootEntry :: PB.BlockPair arch -> NodeEntry arch +rootEntry :: PB.BinaryPair (PB.ConcreteBlock arch) qbin -> NodeEntry' arch qbin rootEntry pPair = NodeEntry (CallingContext [] Nothing) pPair -rootReturn :: PB.FunPair arch -> NodeReturn arch +rootReturn :: PB.BinaryPair (PB.FunctionEntry arch) qbin -> NodeReturn' arch qbin rootReturn pPair = NodeReturn (CallingContext [] Nothing) pPair -addContext :: PB.BlockPair arch -> NodeEntry arch -> NodeEntry arch -addContext newCtx ne@(NodeEntry (CallingContext ctx d) blks) = +addContext :: PB.BinaryPair (PB.ConcreteBlock arch) qbin1 -> NodeEntry' arch qbin2 -> NodeEntry' arch qbin2 +addContext newCtx' ne@(NodeEntry (CallingContext ctx d) blks) = case elem newCtx ctx of -- avoid recursive loops True -> ne False -> NodeEntry (CallingContext (newCtx:ctx) d) blks + where + newCtx = Qu.QuantSome newCtx' -- Strip diverge points from two-sided nodes. This is used so that -- merged nodes (which are two-sided) can meaningfully retain their -- diverge point, but it will be stripped on any subsequent nodes. -mkNextContext :: PPa.PatchPair a -> CallingContext arch -> CallingContext arch -mkNextContext (PPa.PatchPair{}) cctx = dropDivergePoint cctx -mkNextContext _ cctx = cctx +mkNextContext :: Quant a (bin :: QuantK PB.WhichBinary) -> CallingContext arch -> CallingContext arch +mkNextContext q cctx = case q of + Qu.All{} -> dropDivergePoint cctx + Qu.Single{} -> cctx dropDivergePoint :: CallingContext arch -> CallingContext arch dropDivergePoint (CallingContext cctx _) = CallingContext cctx Nothing @@ -295,30 +333,30 @@ splitGraphNode nd = do return (nodeO, nodeP) -- | Get the node corresponding to the entry point for the function -returnToEntry :: NodeReturn arch -> NodeEntry arch -returnToEntry (NodeReturn ctx fns) = NodeEntry (mkNextContext fns ctx) (TF.fmapF PB.functionEntryToConcreteBlock fns) +returnToEntry :: NodeReturn' arch bin -> NodeEntry' arch bin +returnToEntry (NodeReturn ctx fns) = NodeEntry (mkNextContext fns ctx) (Qu.map PB.functionEntryToConcreteBlock fns) -- | Get the return node that this entry would return to -returnOfEntry :: NodeEntry arch -> NodeReturn arch -returnOfEntry (NodeEntry ctx blks) = NodeReturn (mkNextContext blks ctx) (TF.fmapF PB.blockFunctionEntry blks) +returnOfEntry :: NodeEntry' arch bin -> NodeReturn' arch bin +returnOfEntry (NodeEntry ctx blks) = NodeReturn (mkNextContext blks ctx) (Qu.map PB.blockFunctionEntry blks) -- | For an intermediate entry point in a function, find the entry point -- corresponding to the function start -functionEntryOf :: NodeEntry arch -> NodeEntry arch -functionEntryOf (NodeEntry ctx blks) = NodeEntry (mkNextContext blks ctx) (TF.fmapF (PB.functionEntryToConcreteBlock . PB.blockFunctionEntry) blks) +functionEntryOf :: NodeEntry' arch bin -> NodeEntry' arch bin +functionEntryOf (NodeEntry ctx blks) = NodeEntry (mkNextContext blks ctx) (Qu.map (PB.functionEntryToConcreteBlock . PB.blockFunctionEntry) blks) instance PA.ValidArch arch => Show (CallingContext arch) where show c = show (pretty c) -instance PA.ValidArch arch => Show (NodeEntry arch) where +instance PA.ValidArch arch => Show (NodeEntry' arch bin) where show e = show (pretty e) -instance PA.ValidArch arch => Pretty (NodeEntry arch) where +instance PA.ValidArch arch => Pretty (NodeEntry' arch bin) where pretty e = case functionEntryOf e == e of True -> case graphNodeContext e of CallingContext [] _ -> pretty (nodeBlocks e) _ -> pretty (nodeBlocks e) <+> "[" <+> pretty (graphNodeContext e) <+> "]" - False -> PPa.ppPatchPair' PB.ppBlockAddr (nodeBlocks e) + False -> PB.ppBinaryPair' PB.ppBlockAddr (nodeBlocks e) <+> "[" <+> pretty (graphNodeContext (addContext (nodeBlocks (functionEntryOf e)) e)) <+> "]" instance PA.ValidArch arch => Pretty (NodeReturn arch) where @@ -348,7 +386,7 @@ tracePrettyNode nd msg = case nd of "" -> "Return" <+> pretty ret _ -> "Return" <+> pretty ret <+> PP.parens (pretty msg) -instance PA.ValidArch arch => JSON.ToJSON (NodeEntry arch) where +instance PA.ValidArch arch => JSON.ToJSON (NodeEntry' arch bin) where toJSON e = JSON.object [ "type" JSON..= entryType , "context" JSON..= graphNodeContext e @@ -360,7 +398,7 @@ instance PA.ValidArch arch => JSON.ToJSON (NodeEntry arch) where True -> "function_entry" False -> "function_body" -instance PA.ValidArch arch => JSON.ToJSON (NodeReturn arch) where +instance PA.ValidArch arch => JSON.ToJSON (NodeReturn' arch bin) where toJSON e = JSON.object [ "context" JSON..= returnNodeContext e , "functions" JSON..= nodeFuns e @@ -393,49 +431,40 @@ instance forall sym arch. PA.ValidArch arch => IsTraceNode '(sym, arch) "entryno -- | Equivalent to a 'NodeEntry' but necessarily a single-sided node. -- Converting a 'SingleNodeEntry' to a 'NodeEntry' is always defined, -- while converting a 'NodeEntry' to a 'SingleNodeEntry' is partial. -data SingleNodeEntry arch bin = - SingleNodeEntry - { singleEntryBin :: PB.WhichBinaryRepr bin - , _singleEntry :: NodeContent arch (PB.ConcreteBlock arch bin) - } -singleNodeAddr :: SingleNodeEntry arch bin -> PPa.WithBin (PAd.ConcreteAddress arch) bin -singleNodeAddr se = PPa.WithBin (singleEntryBin se) (PB.concreteAddress (singleNodeBlock se)) +type SingleNodeEntry arch bin = NodeEntry' arch (Qu.OneK bin) -mkSingleNodeEntry :: NodeEntry arch -> PB.ConcreteBlock arch bin -> SingleNodeEntry arch bin -mkSingleNodeEntry node blk = SingleNodeEntry (PB.blockBinRepr blk) (NodeContent (graphNodeContext node) blk) +pattern SingleNodeEntry :: CallingContext arch -> PB.ConcreteBlock arch bin -> SingleNodeEntry arch bin +pattern SingleNodeEntry cctx blk <- ((\l -> case l of NodeEntry cctx (Qu.Single _ blk) -> (cctx,blk)) -> (cctx,blk)) + where + SingleNodeEntry cctx blk = NodeEntry cctx (Qu.Single (PB.blockBinRepr blk) blk) -instance TestEquality (SingleNodeEntry arch) where - testEquality se1 se2 | EQF <- compareF se1 se2 = Just Refl - testEquality _ _ = Nothing +{-# COMPLETE SingleNodeEntry #-} -instance Eq (SingleNodeEntry arch bin) where - se1 == se2 = compare se1 se2 == EQ +singleEntryBin :: SingleNodeEntry arch bin -> PB.WhichBinaryRepr bin +singleEntryBin (nodeEntryRepr -> Qu.QuantOneRepr repr) = repr -instance Ord (SingleNodeEntry arch bin) where - compare (SingleNodeEntry _ se1) (SingleNodeEntry _ se2) = compare se1 se2 +singleNodeAddr :: SingleNodeEntry arch bin -> PPa.WithBin (PAd.ConcreteAddress arch) bin +singleNodeAddr se = PPa.WithBin (singleEntryBin se) (PB.concreteAddress (singleNodeBlock se)) -instance OrdF (SingleNodeEntry arch) where - compareF (SingleNodeEntry bin1 se1) (SingleNodeEntry bin2 se2) = - lexCompareF bin1 bin2 $ fromOrdering (compare se1 se2) +mkSingleNodeEntry :: NodeEntry' arch qbin -> PB.ConcreteBlock arch bin -> SingleNodeEntry arch bin +mkSingleNodeEntry node blk = SingleNodeEntry (graphNodeContext node) blk -instance PA.ValidArch arch => Show (SingleNodeEntry arch bin) where - show e = show (singleToNodeEntry e) singleNodeDivergePoint :: SingleNodeEntry arch bin -> GraphNode arch -singleNodeDivergePoint (SingleNodeEntry _ (NodeContent cctx _)) = case divergePoint cctx of +singleNodeDivergePoint (NodeEntry cctx _) = case divergePoint cctx of Just dp -> dp Nothing -> panic Verifier "singleNodeDivergePoint" ["missing diverge point for SingleNodeEntry"] -asSingleNodeEntry :: PPa.PatchPairM m => NodeEntry arch -> m (Some (SingleNodeEntry arch)) -asSingleNodeEntry (NodeEntry cctx bPair) = do - Pair bin blk <- PPa.asSingleton bPair +asSingleNodeEntry :: PPa.PatchPairM m => NodeEntry' arch qbin -> m (Some (Qu.AsSingle (NodeEntry' arch))) +asSingleNodeEntry (NodeEntry cctx blks) = do + Pair _ blk <- PPa.asSingleton blks case divergePoint cctx of - Just{} -> return $ Some (SingleNodeEntry bin (NodeContent cctx blk)) + Just{} -> return $ Some (Qu.AsSingle $ SingleNodeEntry cctx blk) Nothing -> PPa.throwPairErr singleNodeBlock :: SingleNodeEntry arch bin -> PB.ConcreteBlock arch bin -singleNodeBlock (SingleNodeEntry _ (NodeContent _ blk)) = blk +singleNodeBlock (SingleNodeEntry _ blk) = blk -- | Returns a 'SingleNodeEntry' for a given 'NodeEntry' if it has an entry -- for the given 'bin'. @@ -450,15 +479,14 @@ toSingleNodeEntry bin ne = do case toSingleNode bin ne of Just (NodeEntry cctx bPair) -> do blk <- PPa.get bin bPair - return $ SingleNodeEntry bin (NodeContent cctx blk) + return $ SingleNodeEntry cctx blk _ -> PPa.throwPairErr singleToNodeEntry :: SingleNodeEntry arch bin -> NodeEntry arch -singleToNodeEntry (SingleNodeEntry bin (NodeContent cctx v)) = - NodeEntry cctx (PPa.PatchPairSingle bin v) +singleToNodeEntry sne = Qu.coerceQuant sne singleNodeDivergence :: SingleNodeEntry arch bin -> GraphNode arch -singleNodeDivergence (SingleNodeEntry _ (NodeContent cctx _)) = case divergePoint cctx of +singleNodeDivergence (SingleNodeEntry cctx _) = case divergePoint cctx of Just dp -> dp Nothing -> panic Verifier "singleNodeDivergence" ["Unexpected missing divergence point"] @@ -466,12 +494,10 @@ combineSingleEntries' :: SingleNodeEntry arch PB.Original -> SingleNodeEntry arch PB.Patched -> Maybe (NodeEntry arch) -combineSingleEntries' (SingleNodeEntry _ eO) (SingleNodeEntry _ eP) = do - GraphNode divergeO <- divergePoint $ nodeContentCtx eO - GraphNode divergeP <- divergePoint $ nodeContentCtx eP +combineSingleEntries' (SingleNodeEntry cctxO blksO) (SingleNodeEntry cctxP blksP) = do + GraphNode divergeO <- divergePoint $ cctxO + GraphNode divergeP <- divergePoint $ cctxP guard $ divergeO == divergeP - let blksO = nodeContent eO - let blksP = nodeContent eP return $ mkNodeEntry divergeO (PPa.PatchPair blksO blksP) -- | Create a combined two-sided 'NodeEntry' based on @@ -486,4 +512,13 @@ combineSingleEntries :: Maybe (NodeEntry arch) combineSingleEntries sne1 sne2 = case singleEntryBin sne1 of PB.OriginalRepr -> combineSingleEntries' sne1 sne2 - PB.PatchedRepr -> combineSingleEntries' sne2 sne1 \ No newline at end of file + PB.PatchedRepr -> combineSingleEntries' sne2 sne1 + +type SingleNodeReturn arch bin = NodeReturn' arch (Qu.OneK bin) + +pattern SingleNodeReturn :: CallingContext arch -> PB.FunctionEntry arch bin -> SingleNodeReturn arch bin +pattern SingleNodeReturn cctx fn <- ((\l -> case l of NodeReturn cctx (Qu.Single _ fn) -> (cctx,fn)) -> (cctx,fn)) + where + SingleNodeReturn cctx fn = NodeReturn cctx (Qu.Single (PB.functionBinRepr fn) fn) + +type SingleGraphNode arch bin = GraphNode' arch (Qu.OneK bin) \ No newline at end of file diff --git a/src/Pate/Verification/StrongestPosts.hs b/src/Pate/Verification/StrongestPosts.hs index ff7e5e48..d4f7daf4 100644 --- a/src/Pate/Verification/StrongestPosts.hs +++ b/src/Pate/Verification/StrongestPosts.hs @@ -60,6 +60,7 @@ import qualified Data.Parameterized.TraversableF as TF import qualified Data.Parameterized.TraversableFC as TFC import Data.Parameterized.Nonce import qualified Data.Parameterized.Context as Ctx +import qualified Data.Quant as Qu import qualified What4.Expr as W4 import qualified What4.Interface as W4 @@ -205,11 +206,11 @@ runVerificationLoop env pPairs = do asEntry :: PB.FunPair arch -> NodeEntry arch asEntry fnPair = let - bPair = TF.fmapF PB.functionEntryToConcreteBlock fnPair + bPair = PPa.map PB.functionEntryToConcreteBlock fnPair in (rootEntry bPair) asRootEntry :: GraphNode arch -> Maybe (PB.FunPair arch ) -asRootEntry (GraphNode ne) = Just (TF.fmapF PB.blockFunctionEntry (nodeBlocks ne)) +asRootEntry (GraphNode ne) = Just (PPa.map PB.blockFunctionEntry (nodeBlocks ne)) asRootEntry (ReturnNode{}) = Nothing -- FIXME: clagged from initializePairGraph @@ -761,7 +762,7 @@ withWorkItem gr0 f = do let nd = workItemNode wi res <- subTraceLabel @"node" (printPriorityKind priority) nd $ atPriority priority Nothing $ do (mnext, gr2) <- case wi of - ProcessNode (GraphNode ne) | Just (Some sne) <- asSingleNodeEntry ne -> do + ProcessNode (GraphNode ne) | Just (Some (Qu.AsSingle sne)) <- asSingleNodeEntry ne -> do (evalPG gr1 $ isSyncNode sne) >>= \case True -> do gr2 <- execPG gr1 $ queueExitMerges (\pk -> mkPriority pk priority) (SyncAtStart sne) @@ -1016,7 +1017,7 @@ instance (PA.ValidArch arch, PSo.ValidSym sym) => W4S.W4Serializable sym (FinalR instance (PA.ValidArch arch, PSo.ValidSym sym) => W4S.W4Serializable sym (ConditionTraces sym arch) where w4Serialize (ConditionTraces p trT trF fps) = do - W4S.object [ "predicate" W4S..== p, "trace_true" W4S..= trT, "trace_false" W4S..= trF, "trace_footprint" W4S..= (TF.fmapF (\(Const(_,v)) -> Const v) fps) ] + W4S.object [ "predicate" W4S..== p, "trace_true" W4S..= trT, "trace_false" W4S..= trF, "trace_footprint" W4S..= (PPa.map (\(Const(_,v)) -> Const v) fps) ] instance (PSo.ValidSym sym, PA.ValidArch arch) => IsTraceNode '(sym,arch) "toplevel_result" where @@ -1086,9 +1087,9 @@ orphanReturnBundle scope pPair = withSym $ \sym -> do simOut_ <- IO.withRunInIO $ \runInIO -> PA.withStubOverride sym archInfo wsolver ov $ \f -> runInIO $ do - outSt <- liftIO $ f (TF.fmapF PS.simInState simIn_) + outSt <- liftIO $ f (PPa.map PS.simInState simIn_) blkend <- liftIO $ MCS.initBlockEnd (Proxy @arch) sym MCS.MacawBlockEndReturn - return $ TF.fmapF (\st' -> PS.SimOutput st' blkend) outSt + return $ PPa.map (\st' -> PS.SimOutput st' blkend) outSt return $ PS.SimBundle simIn_ simOut_ @@ -1195,7 +1196,7 @@ getFunctionAbs node d gr = do Nothing -> do -- this is some sub-block in a function, so use the domain for -- the function entry point - let fnPair = TF.fmapF PB.blockFunctionEntry (nodeBlocks node) + let fnPair = PPa.map PB.blockFunctionEntry (nodeBlocks node) case getCurrentDomain gr (GraphNode node) of Just preSpec -> PS.viewSpec preSpec $ \_ d' -> PPa.forBins $ \bin -> do vals <- PPa.get bin (PAD.absDomVals d') @@ -1258,7 +1259,7 @@ withAbsDomain node d gr f = do defaultInit = PA.validArchInitAbs archData liftIO $ PD.addOverrides defaultInit pfm absSt withParsedFunctionMap pfm_pair $ do - let fnBlks = TF.fmapF (PB.functionEntryToConcreteBlock . PB.blockFunctionEntry) (nodeBlocks node) + let fnBlks = PPa.map (PB.functionEntryToConcreteBlock . PB.blockFunctionEntry) (nodeBlocks node) PPa.catBins $ \bin -> do pfm <- PMC.parsedFunctionMap <$> getBinCtx fnBlks' <- PPa.get bin fnBlks @@ -1429,7 +1430,7 @@ withValidInit :: withValidInit scope bPair f = withPair bPair $ do let vars = PS.scopeVars scope - varsSt = TF.fmapF PS.simVarState vars + varsSt = PPa.map PS.simVarState vars validInit <- PVV.validInitState bPair varsSt validAbs <- PPa.catBins $ \bin -> do @@ -1568,7 +1569,7 @@ visitNode scope (workItemNode -> (ReturnNode fPair)) d gr0 = do priority <- thisPriority let vars = PS.scopeVars scope - varsSt = TF.fmapF PS.simVarState vars + varsSt = PPa.map PS.simVarState vars validState <- PVV.validInitState ret varsSt withCurrentAbsDomain (functionEntryOf node) gr0' $ withAssumptionSet validState $ do (asm, bundle) <- returnSiteBundle scope vars d ret @@ -2467,7 +2468,7 @@ triageBlockTarget scope bundle' paths currBlock st d blkts = withSym $ \sym -> d (_,PPa.PatchPairMismatch{}) -> handleDivergingPaths scope bundle' currBlock st d blkts _ | isMismatchedStubs stubPair -> handleDivergingPaths scope bundle' currBlock st d blkts (Just ecase, PPa.PatchPairJust rets) -> fmap (updateBranchGraph st blkts) $ do - let pPair = TF.fmapF PB.targetCall blkts + let pPair = PPa.map PB.targetCall blkts bundle <- PD.associateFrames bundle' ecase (hasStub stubPair) getSomeGroundTrace scope bundle d Nothing >>= emitTrace @"trace_events" traceBundle bundle (" Return target " ++ show rets) @@ -2487,7 +2488,7 @@ triageBlockTarget scope bundle' paths currBlock st d blkts = withSym $ \sym -> d MCS.MacawBlockEndReturn -> handleReturn scope bundle currBlock d gr _ -> do let - pPair = TF.fmapF PB.targetCall blkts + pPair = PPa.map PB.targetCall blkts nextNode = mkNodeEntry currBlock pPair traceBundle bundle "No return target identified" emitTrace @"message" "No return target identified" @@ -2696,7 +2697,7 @@ handleTerminalFunction :: PairGraph sym arch -> EquivM sym arch (PairGraph sym arch) handleTerminalFunction node gr = do - let fPair = TF.fmapF PB.blockFunctionEntry (nodeBlocks node) + let fPair = PPa.map PB.blockFunctionEntry (nodeBlocks node) return $ addTerminalNode gr (mkNodeReturn node fPair) getStubOverrideOne :: @@ -3021,7 +3022,7 @@ handleStub scope bundle currBlock d gr0_ pPair mpRetPair stubPair = fnTrace "han {- out <- PSi.applySimplifier unfold_simplifier (TF.fmapF PS.simOutState (PS.simOut bundle)) nextStPair_ <- liftIO $ f out nextStPair <- PSi.applySimplifier unfold_simplifier nextStPair_ -} - nextStPair <- liftIO $ f (TF.fmapF PS.simOutState (PS.simOut bundle)) + nextStPair <- liftIO $ f (PPa.map PS.simOutState (PS.simOut bundle)) PPa.forBins $ \bin -> do nextSt <- PPa.get bin nextStPair output <- PPa.get bin (PS.simOut bundle) @@ -3041,7 +3042,7 @@ handleReturn :: PairGraph sym arch -> EquivM sym arch (PairGraph sym arch) handleReturn scope bundle currBlock d gr = - do let fPair = TF.fmapF PB.blockFunctionEntry (nodeBlocks currBlock) + do let fPair = PPa.map PB.blockFunctionEntry (nodeBlocks currBlock) let ret = mkNodeReturn currBlock fPair let next = ReturnNode ret withTracing @"node" next $ @@ -3073,8 +3074,8 @@ mkSimBundle _pg node varsPair = fnTrace "mkSimBundle" $ do simOut_ <- mkSimOut simIn_ return $ TupleF2 simIn_ simOut_ let - simIn_pair = TF.fmapF (\(TupleF2 x _) -> x) results_pair - simOut_pair = TF.fmapF (\(TupleF2 _ x) -> x) results_pair + simIn_pair = PPa.map (\(TupleF2 x _) -> x) results_pair + simOut_pair = PPa.map (\(TupleF2 _ x) -> x) results_pair return (PS.SimBundle simIn_pair simOut_pair) mkSimOut :: diff --git a/src/Pate/Verification/StrongestPosts/CounterExample.hs b/src/Pate/Verification/StrongestPosts/CounterExample.hs index 92d3c16f..382b807c 100644 --- a/src/Pate/Verification/StrongestPosts/CounterExample.hs +++ b/src/Pate/Verification/StrongestPosts/CounterExample.hs @@ -289,7 +289,7 @@ instance (PA.ValidArch arch, PSo.ValidSym sym) => W4S.W4Serializable sym (TraceF instance (PA.ValidArch arch, PSo.ValidSym sym) => W4S.W4Serializable sym (TraceEvents sym arch) where w4Serialize (TraceEvents p pre post) = do - trace_pair <- PPa.w4SerializePair p $ \(TraceEventsOne rop evs) -> + trace_pair <- PB.w4SerializePair p $ \(TraceEventsOne rop evs) -> W4S.object [ "initial_regs" .= rop, "events" .= evs] W4S.object [ "precondition" .= pre, "postcondition" .= post, "traces" .= trace_pair ] diff --git a/src/What4/JSON.hs b/src/What4/JSON.hs index a23a9344..7df96aae 100644 --- a/src/What4/JSON.hs +++ b/src/What4/JSON.hs @@ -372,6 +372,22 @@ class W4Deserializable sym a where w4Deserialize :: W4Deserializable sym a => JSON.Value -> W4DS sym a w4Deserialize v = ask >>= \W4DSEnv{} -> w4Deserialize_ v +withSym :: (W4.IsExprBuilder sym => sym -> W4DS sym a) -> W4DS sym a +withSym f = do + W4DSEnv{} <- ask + sym <- asks w4dsSym + f sym + +lookupIdent :: Integer -> W4DS sym (Some (W4.SymExpr sym)) +lookupIdent ident = withSym $ \_sym -> do + ExprEnv env <- asks w4dsEnv + case Map.lookup ident env of + Just (Some e) -> return $ Some e + Nothing -> fail $ "lookupIdent: Missing identifier '" + ++ show ident + ++ "' from environment:\n" + ++ show (map (\(i, Some e) -> (i, W4.printSymExpr e)) $ Map.toList env) + instance W4Deserializable sym JSON.Value instance W4Deserializable sym String instance W4Deserializable sym Integer @@ -436,9 +452,8 @@ instance W4Deserializable sym (Some (ToDeserializable sym)) where <|> do JSON.Object o <- return v (ident :: Integer) <- o .: "symbolic_ident" - ExprEnv env <- asks w4dsEnv - Just (Some e) <- return $ Map.lookup ident env - return $ Some e + lookupIdent ident + instance forall tp sym. KnownRepr W4.BaseTypeRepr tp => W4Deserializable sym (ToDeserializable sym tp) where w4Deserialize_ v = do diff --git a/tests/TypesTestMain.hs b/tests/TypesTestMain.hs new file mode 100644 index 00000000..9994085e --- /dev/null +++ b/tests/TypesTestMain.hs @@ -0,0 +1,27 @@ +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE DataKinds #-} + +module Main ( main ) where + + +import qualified Test.Tasty as TT +import qualified Test.Tasty.HUnit as TTH + +import qualified QuantTest as QT + +main :: IO () +main = do + let tests = TT.testGroup "TypesTests" $ [ + TT.testGroup "Data.Quant" $ + [ TTH.testCase "testAll" $ QT.testAll + , TTH.testCase "testSingle" $ QT.testSingle + , TTH.testCase "testMap" $ QT.testMap + , TTH.testCase "testOrdering" $ QT.testOrdering + , TTH.testCase "testConversions" $ QT.testOrdering + ] + ] + TT.defaultMain tests \ No newline at end of file diff --git a/tests/types/QuantTest.hs b/tests/types/QuantTest.hs new file mode 100644 index 00000000..25acdb83 --- /dev/null +++ b/tests/types/QuantTest.hs @@ -0,0 +1,224 @@ +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE TypeSynonymInstances #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE TypeOperators #-} + +module QuantTest + ( testSingle + , testAll + , testMap + , testOrdering + , testConversions ) where + +import qualified Test.Tasty.HUnit as TTH + +import Data.Parameterized.Classes +import Data.Parameterized.WithRepr +import Data.Parameterized.Some +import qualified Data.Parameterized.Map as MapF +import Data.Parameterized.Map ( MapF ) + +import qualified Data.Parameterized.TotalMapF as TMF +import qualified Data.Quant as Qu +import Data.Quant ( Quant, QuantK, AllK, ExistsK, OneK ) + +data ColorK = RedK | BlueK + +data ColorRepr (tp :: ColorK) where + RedRepr :: ColorRepr RedK + BlueRepr :: ColorRepr BlueK + +instance KnownRepr ColorRepr RedK where + knownRepr = RedRepr + +instance KnownRepr ColorRepr BlueK where + knownRepr = BlueRepr + +instance Show (ColorRepr tp) where + show c = case c of + RedRepr -> "red" + BlueRepr -> "blue" + +instance ShowF ColorRepr + +instance IsRepr ColorRepr + +instance TestEquality ColorRepr where + testEquality r1 r2 = case (r1, r2) of + (RedRepr, RedRepr) -> Just Refl + (BlueRepr, BlueRepr) -> Just Refl + _ -> Nothing + +instance Eq (ColorRepr tp) where + _ == _ = True + +instance Ord (ColorRepr tp) where + compare _ _ = EQ + +instance OrdF ColorRepr where + compareF r1 r2 = case (r1, r2) of + (RedRepr, RedRepr) -> EQF + (BlueRepr, BlueRepr) -> EQF + (RedRepr, BlueRepr) -> LTF + (BlueRepr, RedRepr) -> GTF + + + +instance TMF.HasTotalMapF ColorRepr where + allValues = [Some RedRepr, Some BlueRepr] + +instance Qu.HasReprK ColorK where + type ReprOf = ColorRepr + +data Bucket (tp :: ColorK) = Bucket { bucketSize :: Int } + deriving (Eq, Ord, Show) + +instance ShowF Bucket + +instance Show (Qu.Quant Bucket tp) where + show qb = case qb of + Qu.Single c (Bucket sz) -> show c ++ " bucket: " ++ show sz + Qu.All f -> + let + Bucket redsz = f RedRepr + Bucket bluesz = f BlueRepr + in "red: " ++ show redsz ++ " and blue: " ++ show bluesz + +instance ShowF (Qu.Quant Bucket) + +data BucketContainer (tp :: QuantK ColorK) = BucketContainer { buckets :: [Quant Bucket tp]} + deriving (Show, Eq) + +instance ShowF BucketContainer + +instance Qu.QuantCoercible BucketContainer where + coerceQuant (BucketContainer bs) = BucketContainer (map Qu.coerceQuant bs) + +instance Qu.QuantConvertible BucketContainer where + convertQuant (BucketContainer bs) = BucketContainer <$> mapM Qu.convertQuant bs + +assertEq:: (Show f, Eq f) => f -> f -> IO () +assertEq f1 f2 = case f1 == f2 of + True -> return () + False -> TTH.assertFailure $ "assertEq failure: " ++ show f1 ++ " vs. " ++ show f2 + +assertInEq:: (Show f, Eq f) => f -> f -> IO () +assertInEq f1 f2 = case f1 == f2 of + True -> TTH.assertFailure $ "assertInEq failure: " ++ show f1 ++ " vs. " ++ show f2 + False -> return () + + +assertEquality' :: (ShowF f, TestEquality f) => f tp1 -> f tp2 -> IO (tp1 :~: tp2) +assertEquality' f1 f2 = case testEquality f1 f2 of + Just Refl -> return Refl + Nothing -> TTH.assertFailure $ "assertEquality failure: " ++ showF f1 ++ " vs. " ++ showF f2 + +assertEquality :: (ShowF f, TestEquality f) => f tp1 -> f tp2 -> IO () +assertEquality f1 f2 = assertEquality' f1 f2 >> return () + +qRed0 :: Qu.Quant Bucket (OneK RedK) +qRed0 = Qu.QuantOne RedRepr (Bucket 0) + +qAll0 :: Qu.Quant Bucket AllK +qAll0 = Qu.QuantAll (TMF.mapWithKey (\c _ -> case c of RedRepr -> Bucket 0; BlueRepr -> Bucket 1) TMF.totalMapRepr) + +testSingle :: IO () +testSingle = do + let (q2 :: Qu.Quant Bucket (OneK RedK)) = Qu.Single RedRepr (Bucket 0) + assertEq qRed0 q2 + assertEquality qRed0 q2 + + +testAll :: IO () +testAll = do + let (q2 :: Qu.Quant Bucket AllK) = Qu.All (\c -> case c of RedRepr -> Bucket 0; BlueRepr -> Bucket 1) + assertEq qAll0 q2 + assertEquality qAll0 q2 + +testMap :: IO () +testMap = do + assertEq (Qu.map (\(Bucket sz) -> Bucket (sz + 1)) qRed0) (Qu.QuantOne RedRepr (Bucket 1)) + assertInEq (Qu.map (\(Bucket sz) -> Bucket (sz + 1)) qRed0) (Qu.map (\(Bucket sz) -> Bucket (sz + 2)) qRed0) + + assertEq (Qu.map (\(Bucket sz) -> Bucket (sz + 1)) qAll0) (Qu.All (\c -> case c of RedRepr -> Bucket 1; BlueRepr -> Bucket 2)) + assertInEq (Qu.map (\(Bucket sz) -> Bucket (sz + 1)) qAll0) (Qu.map (\(Bucket sz) -> Bucket (sz + 2)) qAll0) + + +testOrdering :: IO () +testOrdering = do + let (bucketRed :: BucketContainer (OneK RedK)) = BucketContainer [qRed0] + let (bucketAll :: BucketContainer AllK) = BucketContainer [qAll0] + + let (m :: MapF (Qu.Quant Bucket) BucketContainer) = MapF.fromList [MapF.Pair qRed0 bucketRed, MapF.Pair qAll0 bucketAll] + assertEq (MapF.lookup qRed0 m) (Just bucketRed) + assertEq (MapF.lookup qAll0 m) (Just bucketAll) + + return () + + +testConversions :: IO () +testConversions = do + let (bucketRed :: BucketContainer (OneK RedK)) = BucketContainer [qRed0] + let (bucketRedEx :: BucketContainer ExistsK) = Qu.coerceQuant bucketRed + + () <- case bucketRedEx of + BucketContainer [] -> TTH.assertFailure "bucketRedEx: unexpected empty BucketContainer" + BucketContainer (Qu.All _:_) -> TTH.assertFailure $ "bucketRedEx: Qu.All unexpected match" + BucketContainer (Qu.Single repr x:_) -> do + Refl <- assertEquality' repr RedRepr + assertEq (Bucket 0) x + + let (mbucketRed :: Maybe (BucketContainer (OneK RedK))) = Qu.convertQuant bucketRedEx + assertEq (Just bucketRed) mbucketRed + + let (mbucketBlue :: Maybe (BucketContainer (OneK BlueK))) = Qu.convertQuant bucketRedEx + -- known invalid conversion + assertEq Nothing mbucketBlue + + let (bucketAll :: BucketContainer AllK) = BucketContainer [qAll0] + let (bucketAllEx :: BucketContainer ExistsK) = Qu.coerceQuant bucketAll + + case bucketAllEx of + BucketContainer [] -> TTH.assertFailure "bucketAllEx: unexpected empty BucketContainer" + BucketContainer (Qu.Single{}:_) -> TTH.assertFailure $ "bucketOneEx: Qu.Single unexpected match" + BucketContainer (Qu.All f:_) -> assertEq qAll0 (Qu.All f) + + let bucketRedFromAll :: BucketContainer (OneK RedK) = Qu.coerceQuant bucketAll + assertEq bucketRed bucketRedFromAll + + let (mbucketAll :: Maybe (BucketContainer AllK)) = Qu.convertQuant bucketAllEx + assertEq (Just bucketAll) mbucketAll + + let (mbucketRedFromExistsAll :: Maybe (BucketContainer (OneK RedK))) = Qu.convertQuant bucketRedEx + assertEq (Just bucketRed) mbucketRedFromExistsAll + + let (bucketMixed :: BucketContainer ExistsK) = BucketContainer [Qu.Single BlueRepr (Bucket 1), Qu.Single RedRepr (Bucket 0)] + let (mixedToAll :: Maybe (BucketContainer AllK)) = Qu.convertQuant bucketMixed + assertEq Nothing mixedToAll + + let (mixedToRed :: Maybe (BucketContainer (OneK RedK))) = Qu.convertQuant bucketMixed + assertEq Nothing mixedToRed + + let (mixedToBlue :: Maybe (BucketContainer (OneK BlueK))) = Qu.convertQuant bucketMixed + assertEq Nothing mixedToBlue + + let (bucketMixedRed :: BucketContainer ExistsK) = BucketContainer [Qu.QuantAny qAll0, Qu.QuantExists qRed0] + + let (mixedRedToRed :: Maybe (BucketContainer (OneK RedK))) = Qu.convertQuant bucketMixedRed + assertEq (Just (BucketContainer [qRed0, qRed0])) mixedRedToRed + + let (mixedRedToBlue :: Maybe (BucketContainer (OneK BlueK))) = Qu.convertQuant bucketMixed + assertEq Nothing mixedRedToBlue + + return () + +