Skip to content

Commit

Permalink
Added CEG-based policy search
Browse files Browse the repository at this point in the history
  • Loading branch information
TheGreatfpmK committed Dec 18, 2023
1 parent cbf70cb commit 6d1bdfb
Showing 1 changed file with 70 additions and 14 deletions.
84 changes: 70 additions & 14 deletions paynt/synthesizer/policy_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import paynt.verification.property_result
from paynt.verification.property import Property
import paynt.quotient.quotient

import logging
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -588,6 +589,8 @@ def create_action_coloring(self):
self.action_coloring = coloring
return

###############################
#### POLICY SEARCH SECTION ####

def update_scores(self, score_lists, selection):
for hole, score_list in score_lists.items():
Expand Down Expand Up @@ -731,13 +734,26 @@ def synthesize_policy_for_family(self, family, prop, all_sat=False, iteration_li
return False, unsat_mdp_families, sat_mdp_families, sat_mdp_policies


def double_check_policy_synthesis(self, unsat_mdp_families, sat_mdp_families, sat_mdp_policies, sat_mdp_to_policy_map, prop):
    '''
    Re-verify the outcome of a policy search (sanity check).
    :param unsat_mdp_families families reported as UNSAT; each is rebuilt and
        model checked again and must not satisfy prop
    :param sat_mdp_families families reported as SAT
    :param sat_mdp_policies policies found for the SAT families
    :param sat_mdp_to_policy_map for the i-th SAT family, the index of its
        policy in sat_mdp_policies
    :param prop the property the families are checked against
    :raises AssertionError with "double check fail" if an UNSAT family turns
        out to satisfy prop
    '''
    quotient = self.quotient

    # families claimed to be UNSAT must still violate the property after a rebuild
    for family in unsat_mdp_families:
        quotient.build(family)
        recheck = family.mdp.model_check_property(prop)
        assert not recheck.sat, "double check fail"

    # families claimed to be SAT are handed, together with their mapped policy,
    # to double_check_policy (the family is rebuilt before the policy lookup)
    for position, family in enumerate(sat_mdp_families):
        quotient.build(family)
        policy_index = sat_mdp_to_policy_map[position]
        policy = sat_mdp_policies[policy_index]
        SynthesizerPolicyTree.double_check_policy(quotient, family, prop, policy)


def synthesize_policy_for_family_linear(self, family, prop):
'''
Synthesize policies for mdps in family in linear time with respect to family size
:returns a list of UNSAT MDPs
:returns a list of SAT MDPs
:returns list of policies for SAT MDPs
:returns list that maps each SAT MDP to its policy
:returns a list of UNSAT MDP families
:returns a list of SAT MDP families
:returns list of policies for SAT MDP families
:returns list that maps each SAT MDP family to its policy
'''
sat_mdp_families = []
sat_mdp_policies = []
Expand All @@ -757,6 +773,7 @@ def synthesize_policy_for_family_linear(self, family, prop):

result = family.mdp.model_check_property(prop)
self.stat.iteration_mdp(family.mdp.states)
self.explore(family)
if not result.sat:
unsat_mdp_families.append(family)

Expand All @@ -778,12 +795,21 @@ def synthesize_policy_for_family_linear(self, family, prop):

return unsat_mdp_families, sat_mdp_families, sat_mdp_policies, sat_mdp_to_policy_map


def synthesize_policy_for_family_using_ceg(self, family, prop):
'''
Synthesize policies for mdps in family using counter-example generalization
:returns a list of UNSAT MDP families
:returns a list of SAT MDP families
:returns list of policies for SAT MDP families
:returns list that maps each SAT MDP family to its policy
'''

sat_mdp_families = []
sat_mdp_policies = []
sat_mdp_to_policy_map = []
unsat_mdp_families = []

self.quotient.build(family)
conflict_generator = ConflictGeneratorStorm(self.quotient)
conflict_generator.initialize()

smt_solver = paynt.family.smt.SmtSolver(family)

Expand All @@ -800,14 +826,44 @@ def synthesize_policy_for_family_using_ceg(self, family, prop):
# Potential for MDP CEs here
pruned = smt_solver.exclude_conflicts(family, mdp_subfamily, [list(range(family.num_holes))])
self.explored += pruned
unsat_mdp_families.append(mdp_subfamily)
elif result.sat == True:
policy = self.quotient.scheduler_to_policy(result.result.scheduler, mdp_subfamily.mdp)
policy_fixed, dtmc_family_quotient_mdp = self.quotient.fix_and_apply_policy_to_family(mdp_subfamily, policy)

policy_fixed, policy_quotient_mdp = self.quotient.fix_and_apply_policy_to_family(family, policy)
quotient_assignment = self.quotient.coloring.getChoiceToAssignment()
choice_to_hole_options = []
for choice in range(policy_quotient_mdp.choices):
quotient_choice = policy_quotient_mdp.quotient_choice_map[choice]
choice_to_hole_options.append(quotient_assignment[quotient_choice])

coloring = stormpy.synthesis.Coloring(family.family, policy_quotient_mdp.model.nondeterministic_choice_indices, choice_to_hole_options)
quotient_container = paynt.quotient.quotient.DtmcFamilyQuotient(policy_quotient_mdp.model, family, coloring, self.quotient.specification.negate())
conflict_generator = ConflictGeneratorStorm(quotient_container)
conflict_generator.initialize()
mdp_subfamily.constraint_indices = family.constraint_indices
requests = [(0, quotient_container.specification.all_properties()[0], result.result, None)]
dtmc = quotient_container.build_chain(mdp_subfamily)
conflicts, _ = conflict_generator.construct_conflicts(family, mdp_subfamily, dtmc, requests, None)
pruned = smt_solver.exclude_conflicts(family, mdp_subfamily, conflicts)
self.explored += pruned

sat_family = family.copy()
for hole_index in range(self.quotient.design_space.num_holes):
if hole_index in conflicts[0]:
sat_family.hole_set_options(hole_index, mdp_subfamily.hole_options(hole_index))

sat_mdp_families.append(sat_family)
sat_mdp_to_policy_map.append(len(sat_mdp_policies))
sat_mdp_policies.append(policy_fixed)
else:
assert False, "result for MDP model checking is not SAT nor UNSAT"

mdp_subfamily = smt_solver.pick_assignment(family)

return unsat_mdp_families, sat_mdp_families, sat_mdp_policies, sat_mdp_to_policy_map

#### POLICY SEARCH SECTION END ####
###################################



Expand All @@ -817,15 +873,15 @@ def synthesize_policy_tree(self, family):
policy_tree = PolicyTree(family)
self.create_action_coloring()

self.synthesize_policy_for_family_using_ceg(policy_tree.root.family, prop)
exit()


unsat, sat, policies, policy_map = self.synthesize_policy_for_family_linear(policy_tree.root.family, prop)
# unsat, sat, policies, policy_map = self.synthesize_policy_for_family_using_ceg(policy_tree.root.family, prop)

print(f'unSAT: {len(unsat)}')
print(f'SAT: {len(sat)}')
print(f'policies: {len(policies)}')
print(self.stat.iterations_mdp)

self.double_check_policy_synthesis(unsat, sat, policies, policy_map, prop)
exit()

if False:
Expand Down Expand Up @@ -856,7 +912,7 @@ def synthesize_policy_tree(self, family):
# refine
suboptions,subfamilies = self.split(family, prop, result.hole_selection, result.splitter)
if policy_tree_node != policy_tree.root:
family.mdp = None # memory optimization
family.mdp = None # memory optimization
policy_tree_node.split(result.splitter,suboptions,subfamilies)
policy_tree_leaves = policy_tree_leaves + policy_tree_node.child_nodes

Expand Down

0 comments on commit 6d1bdfb

Please sign in to comment.