Skip to content

Commit

Permalink
Merge pull request #32 from TheGreatfpmK/new-master
Browse files Browse the repository at this point in the history
Added 2 new methods for policy search
  • Loading branch information
TheGreatfpmK authored Dec 18, 2023
2 parents 33a8987 + 63e446b commit e80481c
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 5 deletions.
2 changes: 1 addition & 1 deletion paynt/quotient/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def hole_simple(self):
hole_to_states = [0 for _ in range(num_holes)]
for state in range(self.states):
quotient_state = self.quotient_state_map[state]
for hole in quotient_container.state_to_holes[quotient_state]:
for hole in self.quotient_container.state_to_holes[quotient_state]:
hole_to_states[hole] += 1
self.hole_is_simple = [hole_to_states[hole] <= 1 for hole in range(num_holes)]
return self.hole_is_simple
Expand Down
2 changes: 1 addition & 1 deletion paynt/quotient/quotient.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def build_with_second_coloring(self, family, main_coloring, main_family):

# select actions compatible with the family and restrict the quotient
choices_alt = self.coloring.selectCompatibleChoices(family.family)
choices_main = main_coloring.selectCompatibleChoices(main_family)
choices_main = main_coloring.selectCompatibleChoices(main_family.family)

choices = choices_main.__and__(choices_alt)
main_family.mdp = self.build_from_choice_mask(choices)
Expand Down
10 changes: 9 additions & 1 deletion paynt/synthesizer/conflict_generator/storm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,18 @@ def name(self):

def initialize(self):
    '''
    Create the Storm counterexample generator for the quotient and store
    it in self.counterexample_generator.

    The coloring's getStateToHoles() result is converted entry by entry
    into Python sets of hole indices before being handed to Storm. Each
    entry is only assumed to expose size() and get(i), i.e. to behave like
    a bitvector (presumably one entry per quotient state -- confirm
    against the coloring implementation).
    '''
    quotient_relevant_holes = self.quotient.coloring.getStateToHoles()
    # TODO this is not a nice solution, it would be nice to remake the Storm code to work with bitvectors
    # keep exactly the indices whose bit is set in the respective entry
    state_to_holes = [
        {hole for hole in range(bitvector.size()) if bitvector.get(hole)}
        for bitvector in quotient_relevant_holes
    ]
    formulae = self.quotient.specification.stormpy_formulae()
    # the generator receives the converted sets, not the raw bitvectors
    self.counterexample_generator = stormpy.synthesis.CounterexampleGenerator(
        self.quotient.quotient_mdp, self.quotient.design_space.num_holes,
        state_to_holes, formulae)


def construct_conflicts(self, family, assignment, dtmc, conflict_requests, accepting_assignment):
Expand Down
164 changes: 162 additions & 2 deletions paynt/synthesizer/policy_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,13 @@
import paynt.quotient.models
import paynt.synthesizer.synthesizer

from .conflict_generator.storm import ConflictGeneratorStorm
from .conflict_generator.mdp import ConflictGeneratorMdp
import paynt.family.smt

import paynt.verification.property_result
from paynt.verification.property import Property
import paynt.quotient.quotient

import logging
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -584,6 +589,8 @@ def create_action_coloring(self):
self.action_coloring = coloring
return

###############################
#### POLICY SEARCH SECTION ####

def update_scores(self, score_lists, selection):
for hole, score_list in score_lists.items():
Expand Down Expand Up @@ -725,14 +732,167 @@ def synthesize_policy_for_family(self, family, prop, all_sat=False, iteration_li
sat_mdp_policies[mdp_index] = policy

return False, unsat_mdp_families, sat_mdp_families, sat_mdp_policies


def double_check_policy_synthesis(self, unsat_mdp_families, sat_mdp_families, sat_mdp_policies, sat_mdp_to_policy_map, prop):
    '''
    Re-verify the outcome of a policy search.

    Every UNSAT family is rebuilt and model checked again against prop;
    the check must not report SAT. Every SAT family is rebuilt and handed,
    together with the policy assigned to it, to
    SynthesizerPolicyTree.double_check_policy.

    :param unsat_mdp_families families reported as UNSAT
    :param sat_mdp_families families reported as SAT
    :param sat_mdp_policies policies found for the SAT families
    :param sat_mdp_to_policy_map for the i-th SAT family, the index of its
        policy in sat_mdp_policies
    :param prop the property to check
    '''
    # UNSAT side: a fresh model check of the rebuilt MDP must agree
    for family in unsat_mdp_families:
        self.quotient.build(family)
        recheck = family.mdp.model_check_property(prop)
        assert not recheck.sat, "double check fail"
        # drop the reference to the MDP built above
        family.mdp = None

    # SAT side: validate the policy assigned to each family
    for position, family in enumerate(sat_mdp_families):
        self.quotient.build(family)
        policy_index = sat_mdp_to_policy_map[position]
        assigned_policy = sat_mdp_policies[policy_index]
        SynthesizerPolicyTree.double_check_policy(self.quotient, family, prop, assigned_policy)
        # drop the reference to the MDP built above
        family.mdp = None


def synthesize_policy_for_family_linear(self, family, prop):
    '''
    Synthesize policies by enumerating the family: one single-assignment
    subfamily is created per combination of hole options, and each of
    them is built and model checked exactly once. The policy of a SAT
    member is merged into the first already stored policy it can be
    merged with (merge_policies returning something other than None);
    otherwise it is stored as a new policy.

    :returns a list of UNSAT MDP families
    :returns a list of SAT MDP families
    :returns list of policies for SAT MDP families
    :returns list that maps each SAT MDP family to its policy
    '''
    unsat_mdp_families = []
    sat_mdp_families = []
    sat_mdp_policies = []
    sat_mdp_to_policy_map = []

    # phase 1: materialize one subfamily per hole assignment
    mdp_families = []
    for hole_assignment in family.all_combinations():
        subfamily = family.copy()
        for hole_index, hole_option in enumerate(hole_assignment):
            subfamily.hole_set_options(hole_index, [hole_option])
        mdp_families.append(subfamily)

    # phase 2: model check every member and collect/merge policies
    for member in mdp_families:
        self.quotient.build(member)
        result = member.mdp.model_check_property(prop)
        self.stat.iteration_mdp(member.mdp.states)
        self.explore(member)

        if not result.sat:
            member.mdp = None
            unsat_mdp_families.append(member)
            continue

        # the scheduler has to be converted while the MDP is still attached
        policy = self.quotient.scheduler_to_policy(result.result.scheduler, member.mdp)
        member.mdp = None

        # look for the first stored policy this one can be merged into
        policy_index = None
        for index, known_policy in enumerate(sat_mdp_policies):
            merged_policy = merge_policies([known_policy, policy])
            if merged_policy is not None:
                sat_mdp_policies[index] = merged_policy
                policy_index = index
                break

        if policy_index is None:
            # nothing compatible: the policy gets a slot of its own
            policy_index = len(sat_mdp_policies)
            sat_mdp_policies.append(policy)

        sat_mdp_families.append(member)
        sat_mdp_to_policy_map.append(policy_index)

    return unsat_mdp_families, sat_mdp_families, sat_mdp_policies, sat_mdp_to_policy_map


def synthesize_policy_for_family_using_ceg(self, family, prop):
    '''
    Synthesize policies for mdps in family using counter-example
    generalization.

    An SMT solver repeatedly picks an assignment from the family and the
    corresponding MDP is model checked. For an UNSAT MDP a conflict over
    all holes is excluded from the solver. For a SAT MDP its policy is
    fixed and applied to the whole family, conflicts are constructed for
    the negated specification, excluded from the solver, and the first
    conflict decides which holes of the reported SAT family are
    restricted to the picked assignment. Every SAT family is mapped to a
    policy entry of its own.

    :returns a list of UNSAT MDP families
    :returns a list of SAT MDP families
    :returns list of policies for SAT MDP families
    :returns list that maps each SAT MDP family to its policy
    '''
    unsat_mdp_families = []
    sat_mdp_families = []
    sat_mdp_policies = []
    sat_mdp_to_policy_map = []

    self.quotient.build(family)
    smt_solver = paynt.family.smt.SmtSolver(family)

    while True:
        candidate = smt_solver.pick_assignment(family)
        if candidate is None:
            # the solver has no assignment left to offer
            break

        self.quotient.build(candidate)
        result = candidate.mdp.model_check_property(prop)
        self.stat.iteration_mdp(candidate.mdp.states)

        if result.sat == False:
            # Potential for MDP CEs here
            # for now the conflict consists of all holes
            every_hole = list(range(family.num_holes))
            pruned = smt_solver.exclude_conflicts(family, candidate, [every_hole])
            self.explored += pruned
            unsat_mdp_families.append(candidate)
        elif result.sat == True:
            policy = self.quotient.scheduler_to_policy(result.result.scheduler, candidate.mdp)
            policy_fixed, policy_quotient_mdp = self.quotient.fix_and_apply_policy_to_family(family, policy)

            # hole options of each choice of the policy-restricted model,
            # looked up through the corresponding quotient choice
            quotient_assignment = self.quotient.coloring.getChoiceToAssignment()
            choice_to_hole_options = [
                quotient_assignment[policy_quotient_mdp.quotient_choice_map[choice]]
                for choice in range(policy_quotient_mdp.choices)
            ]

            # DTMC family quotient over the restricted model, with the negated specification
            coloring = stormpy.synthesis.Coloring(
                family.family,
                policy_quotient_mdp.model.nondeterministic_choice_indices,
                choice_to_hole_options)
            quotient_container = paynt.quotient.quotient.DtmcFamilyQuotient(
                policy_quotient_mdp.model, family, coloring,
                self.quotient.specification.negate())
            conflict_generator = ConflictGeneratorStorm(quotient_container)
            conflict_generator.initialize()

            candidate.constraint_indices = family.constraint_indices
            first_property = quotient_container.specification.all_properties()[0]
            requests = [(0, first_property, result.result, None)]
            dtmc = quotient_container.build_chain(candidate)
            conflicts, _ = conflict_generator.construct_conflicts(family, candidate, dtmc, requests, None)
            pruned = smt_solver.exclude_conflicts(family, candidate, conflicts)
            self.explored += pruned

            # only the holes of the first conflict are fixed to the picked
            # assignment; all other holes keep the options of the family
            sat_family = family.copy()
            for hole_index in range(self.quotient.design_space.num_holes):
                if hole_index not in conflicts[0]:
                    continue
                sat_family.hole_set_options(hole_index, candidate.hole_options(hole_index))

            sat_mdp_families.append(sat_family)
            sat_mdp_to_policy_map.append(len(sat_mdp_policies))
            sat_mdp_policies.append(policy_fixed)
        else:
            assert False, "result for MDP model checking is not SAT nor UNSAT"

    return unsat_mdp_families, sat_mdp_families, sat_mdp_policies, sat_mdp_to_policy_map

#### POLICY SEARCH SECTION END ####
###################################



def synthesize_policy_tree(self, family):
prop = self.quotient.get_property()
game_solver = self.quotient.build_game_abstraction_solver(prop)
policy_tree = PolicyTree(family)
# self.create_action_coloring()
self.create_action_coloring()


### POLICY SEARCH TESTING
# unsat, sat, policies, policy_map = self.synthesize_policy_for_family_linear(policy_tree.root.family, prop)
# unsat, sat, policies, policy_map = self.synthesize_policy_for_family_using_ceg(policy_tree.root.family, prop)

# print(f'unSAT: {len(unsat)}')
# print(f'SAT: {len(sat)}')
# print(f'policies: {len(policies)}')
# print(self.stat.iterations_mdp)

# self.double_check_policy_synthesis(unsat, sat, policies, policy_map, prop)
# exit()
###

if False:
self.quotient.build(policy_tree.root.family)
Expand Down Expand Up @@ -762,7 +922,7 @@ def synthesize_policy_tree(self, family):
# refine
suboptions,subfamilies = self.split(family, prop, result.hole_selection, result.splitter)
if policy_tree_node != policy_tree.root:
family.mdp = None # memory optimization
family.mdp = None # memory optimization
policy_tree_node.split(result.splitter,suboptions,subfamilies)
policy_tree_leaves = policy_tree_leaves + policy_tree_node.child_nodes

Expand Down

0 comments on commit e80481c

Please sign in to comment.