diff --git a/symbolic/src/Data/Macaw/Symbolic/Testing.hs b/symbolic/src/Data/Macaw/Symbolic/Testing.hs index b8562fcf..53414032 100644 --- a/symbolic/src/Data/Macaw/Symbolic/Testing.hs +++ b/symbolic/src/Data/Macaw/Symbolic/Testing.hs @@ -3,6 +3,7 @@ {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE ImplicitParams #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiWayIf #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RankNTypes #-} @@ -44,6 +45,7 @@ import qualified Control.Exception as X import qualified Control.Lens as L import Control.Lens ( (&), (%~) ) import Control.Monad ( when ) +import Control.Monad.Except ( runExceptT ) import qualified Data.Bits as Bits import qualified Data.BitVector.Sized as BVS import qualified Data.ByteString as BS @@ -77,6 +79,7 @@ import GHC.TypeNats ( type (<=) ) import qualified Lang.Crucible.Analysis.Postdom as CAP import qualified Lang.Crucible.Backend as CB import qualified Lang.Crucible.Backend.Online as CBO +import qualified Lang.Crucible.Backend.Prove as Prove import qualified Lang.Crucible.CFG.Core as CCC import qualified Lang.Crucible.CFG.Extension as CCE import qualified Lang.Crucible.FunctionHandle as CFH @@ -87,6 +90,8 @@ import qualified Lang.Crucible.Simulator as CS import qualified Lang.Crucible.Simulator.GlobalState as CSG import qualified Lang.Crucible.Simulator.PathSatisfiability as CSP import qualified Lang.Crucible.Types as CT +import qualified Lang.Crucible.Utils.Seconds as Sec +import qualified Lang.Crucible.Utils.Timeout as CTO import qualified System.FilePath as SF import qualified System.IO as IO import qualified What4.BaseTypes as WT @@ -315,42 +320,6 @@ functionNameFromByteString :: BS.ByteString -> WF.FunctionName functionNameFromByteString = WF.functionNameFromText . Text.decodeUtf8With Text.lenientDecode -proveOneGoal - :: ( CB.IsSymInterface sym - , sym ~ WE.ExprBuilder t st fs - ) - => WS.SolverAdapter st - -> sym - -> CB.Assumptions sym - -> WL.LabeledPred (WI.Pred sym) CS.SimError - -> IO () -proveOneGoal goalSolver sym asmps lp = do - assumptions <- CB.assumptionsPred sym asmps - goal <- WI.notPred sym (lp L.^. WL.labeledPred) - WS.solver_adapter_check_sat goalSolver sym WS.defaultLogData [assumptions, goal] $ \satRes -> - case satRes of - WSR.Unsat {} -> return () - WSR.Sat {} -> error ("Failed to prove goal: " ++ show (lp L.^. WL.labeledPredMsg)) - WSR.Unknown {} -> error ("Failed to prove goal: " ++ show (lp L.^. WL.labeledPredMsg)) - return () - -proveGoals - :: ( CB.IsSymInterface sym, sym ~ WE.ExprBuilder t st fs ) - => WS.SolverAdapter st - -> sym - -> Maybe (CB.Goals (CB.Assumptions sym) (CB.Assertion sym)) - -> IO () -proveGoals goalSolver sym = mapM_ (go mempty) - where - go asmps gs = - case gs of - CB.Assuming as gs1 -> go (asmps <> as) gs1 - CB.Prove lp -> proveOneGoal goalSolver sym asmps lp - CB.ProveConj g1 g2 -> do - go asmps g1 - go asmps g2 - return () - -- | Convert the given function into a Crucible CFG, symbolically execute it, -- and treat the return value as an assertion to be verified. -- @@ -431,8 +400,19 @@ simulateAndVerify goalSolver logger bak execFeatures archInfo archVals binfo mmP CS.PartialRes {} -> return SimulationPartial CS.TotalRes gp -> do when ("test_and_verify_" `Text.isPrefixOf` WF.functionName funName) $ do - goals <- CB.getProofObligations bak - proveGoals goalSolver sym goals + let timeout = CTO.Timeout (Sec.secondsFromInt 5) + let prover = Prove.offlineProver timeout sym WS.defaultLogData goalSolver + let strat = Prove.ProofStrategy prover Prove.keepGoing + let printer = Prove.ProofConsumer $ \goal res -> do + let lp = CB.proofGoal goal + case res of + Prove.Proved {} -> return () + Prove.Disproved {} -> error ("Failed to prove goal: " ++ show (lp L.^. WL.labeledPredMsg)) + Prove.Unknown {} -> error ("Failed to prove goal: " ++ show (lp L.^. WL.labeledPredMsg)) + runExceptT (Prove.proveCurrentObligations bak strat printer) >>= + \case + Left CTO.TimedOut -> error "Timeout when proving goals!" + Right () -> pure () postMem <- case CSG.lookupGlobal memVar (gp L.^. CS.gpGlobals) of Just postMem -> pure postMem