Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added MDP CEs to policy search #34

Merged
merged 5 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions paynt/quotient/mdp_family.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,21 @@ def fix_and_apply_policy_to_family(self, family, policy):
policy_fixed[state] = policy[state]

return policy_fixed,mdp


def apply_policy_to_family(self, family, policy):
    """Build the MDP induced on ``family`` by a (possibly partial) policy.

    :param family: family whose ``selected_choices`` restrict the result
    :param policy: sequence indexed by state; each entry is either an
        action index or ``None`` when the policy leaves the state open
    :return: whatever ``self.build_from_choice_mask`` produces for the
        resulting choice mask
    """
    # presumably state_action_choices[state][action] is a list of quotient
    # choice indices -- verify against where the attribute is populated
    policy_choices = []
    for state, action in enumerate(policy):
        choices_by_action = self.state_action_choices[state]
        if action is not None:
            # the policy fixes this state: keep only the chosen action
            policy_choices += choices_by_action[action]
            continue
        # no action prescribed: keep the choices of every action
        for action_choices in choices_by_action:
            policy_choices += action_choices

    # combine the policy's choices with the choices selected by the family
    choice_mask = stormpy.synthesis.policyToChoicesForFamily(
        policy_choices, family.selected_choices
    )
    return self.build_from_choice_mask(choice_mask)


def assert_mdp_is_deterministic(self, mdp, family):
Expand Down
66 changes: 48 additions & 18 deletions paynt/synthesizer/policy_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,9 @@ def synthesize_policy_for_family_using_ceg(self, family, prop):

smt_solver = paynt.family.smt.SmtSolver(family)

unsat_conflict_generator = paynt.synthesizer.conflict_generator.mdp.ConflictGeneratorMdp(self.quotient)
unsat_conflict_generator.initialize()

mdp_subfamily = smt_solver.pick_assignment(family)

while mdp_subfamily is not None:
Expand All @@ -830,13 +833,31 @@ def synthesize_policy_for_family_using_ceg(self, family, prop):
self.stat.iteration(mdp_subfamily.mdp)

if not result.sat:
# Potential for MDP CEs here
pruned = smt_solver.exclude_conflicts(family, mdp_subfamily, [list(range(family.num_holes))])
# MDP CE
requests = [(0, self.quotient.specification.all_properties()[0], None)]
choices = self.quotient.coloring.selectCompatibleChoices(mdp_subfamily.family)
model,state_map,choice_map = self.quotient.restrict_quotient(choices)
model = paynt.quotient.models.MDP(model,self.quotient,state_map,choice_map,mdp_subfamily)
conflicts = unsat_conflict_generator.construct_conflicts(family, mdp_subfamily, model, requests)

# conflicts = [list(range(family.num_holes))] # UNSAT without CE generalization

pruned = smt_solver.exclude_conflicts(family, mdp_subfamily, conflicts)
self.explored += pruned
unsat_mdp_families.append(mdp_subfamily)

# MDP CE
unsat_family = family.copy()
for hole_index in range(self.quotient.design_space.num_holes):
if hole_index in conflicts[0]:
unsat_family.hole_set_options(hole_index, mdp_subfamily.hole_options(hole_index))

mdp_subfamily.mdp = None
unsat_family.mdp = None
unsat_mdp_families.append(unsat_family)
else:
policy = self.quotient.scheduler_to_policy(result.result.scheduler, mdp_subfamily.mdp)
policy_fixed, policy_quotient_mdp = self.quotient.fix_and_apply_policy_to_family(family, policy)
policy, policy_quotient_mdp = self.quotient.fix_and_apply_policy_to_family(family, policy) # DTMC CE
# policy_quotient_mdp = self.quotient.apply_policy_to_family(family, policy) # MDP SAT CE
quotient_assignment = self.quotient.coloring.getChoiceToAssignment()
choice_to_hole_options = []
for choice in range(policy_quotient_mdp.choices):
Expand All @@ -845,16 +866,17 @@ def synthesize_policy_for_family_using_ceg(self, family, prop):

coloring = stormpy.synthesis.Coloring(family.family, policy_quotient_mdp.model.nondeterministic_choice_indices, choice_to_hole_options)
quotient_container = paynt.quotient.quotient.DtmcFamilyQuotient(policy_quotient_mdp.model, family, coloring, self.quotient.specification.negate())
conflict_generator = paynt.synthesizer.conflict_generator.dtmc.ConflictGeneratorDtmc(quotient_container)
# conflict_generator = paynt.synthesizer.conflict_generator.mdp.ConflictGeneratorMdp(quotient_container)
conflict_generator = paynt.synthesizer.conflict_generator.dtmc.ConflictGeneratorDtmc(quotient_container) # DTMC CE
# conflict_generator = paynt.synthesizer.conflict_generator.mdp.ConflictGeneratorMdp(quotient_container) # MDP SAT CE
conflict_generator.initialize()
mdp_subfamily.constraint_indices = family.constraint_indices
requests = [(0, quotient_container.specification.all_properties()[0], None)]
model = quotient_container.build_assignment(mdp_subfamily)

# choices = coloring.selectCompatibleChoices(mdp_subfamily.family)
# model,state_map,choice_map = quotient_container.restrict_quotient(choices)
# model = paynt.quotient.models.MDP(model,quotient_container,state_map,choice_map,mdp_subfamily)
model = quotient_container.build_assignment(mdp_subfamily) # DTMC CE

# choices = coloring.selectCompatibleChoices(mdp_subfamily.family) # MDP SAT CE
# model,state_map,choice_map = quotient_container.restrict_quotient(choices) # MDP SAT CE
# model = paynt.quotient.models.MDP(model,quotient_container,state_map,choice_map,mdp_subfamily) # MDP SAT CE

conflicts = conflict_generator.construct_conflicts(family, mdp_subfamily, model, requests)
pruned = smt_solver.exclude_conflicts(family, mdp_subfamily, conflicts)
Expand All @@ -865,9 +887,10 @@ def synthesize_policy_for_family_using_ceg(self, family, prop):
if hole_index in conflicts[0]:
sat_family.hole_set_options(hole_index, mdp_subfamily.hole_options(hole_index))

sat_family.mdp = None
sat_mdp_families.append(sat_family)
sat_mdp_to_policy_map.append(len(sat_mdp_policies))
sat_mdp_policies.append(policy_fixed)
sat_mdp_to_policy_map.append(len(sat_mdp_policies))
sat_mdp_policies.append(policy)

mdp_subfamily = smt_solver.pick_assignment(family)

Expand All @@ -882,17 +905,24 @@ def evaluate_all(self, family, prop, keep_value_only=False):
assert not prop.reward, "expecting reachability probability propery"
game_solver = self.quotient.build_game_abstraction_solver(prop)
policy_tree = PolicyTree(family)
self.create_action_coloring()


### POLICY SEARCH TESTING
#self.create_action_coloring()

# choose policy search method
# unsat, sat, policies, policy_map = self.synthesize_policy_for_family_linear(policy_tree.root.family, prop)
# unsat, sat, policies, policy_map = self.synthesize_policy_for_family_using_ceg(policy_tree.root.family, prop)

# self.stat.synthesis_timer.stop()

# unsat_mdps_count = sum([s.size for s in unsat])
# sat_mdps_count = sum([s.size for s in sat])

# print(f'unSAT: {len(unsat)}')
# print(f'SAT: {len(sat)}')
# print(f'policies: {len(policies)}')
# print(self.stat.iterations_mdp)
# print(f'unSAT MDPs: {unsat_mdps_count}\tunSAT families: {len(unsat)}\tavg. unSAT family size: {round(unsat_mdps_count/len(unsat),2) if len(unsat) != 0 else "N/A"}')
# print(f'SAT MDPs: {sat_mdps_count}\tSAT families: {len(sat)}\tavg. SAT family size: {round(sat_mdps_count/len(sat),2) if len(sat) != 0 else "N/A"}')
# print(f'policies: {len(policies)}\tpolicy per SAT MDP: {round(len(policies)/sat_mdps_count,2) if sat_mdps_count != 0 else "N/A"}')
# print(f'iterations: {self.stat.iterations_mdp}')
# print(f'time: {round(self.stat.synthesis_timer.time,2)}s')

# self.double_check_policy_synthesis(unsat, sat, policies, policy_map, prop)
# exit()
Expand Down