Skip to content

Commit

Permalink
MDP controller synthesis via SMT coloring
Browse files Browse the repository at this point in the history
  • Loading branch information
Roman Andriushchenko committed Jul 24, 2024
1 parent 7302697 commit a2ea997
Show file tree
Hide file tree
Showing 25 changed files with 1,182 additions and 194 deletions.
10 changes: 10 additions & 0 deletions models/mdp/maze/sketch.templ
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,18 @@ module maze
endmodule
module test
var : bool init false;
[up] true -> (var'=true);
[right] true -> (var'=false);
[down] true -> (var'=false);
[left] true -> (var'=false);
endmodule
// rewards
label "what" = mod(x,2)=0;
rewards "steps"
clk=1: 1;
endrewards
Expand Down
2 changes: 1 addition & 1 deletion models/mdp/simple/sketch.templ
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
mdp

module m
s : [0..4] init 0;
s : [0..3] init 0;

[up] s=0 -> (s'=1);
[down] s=0 -> (s'=2);
Expand Down
7 changes: 3 additions & 4 deletions paynt/family/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ def hole_num_options_total(self, hole):
return self.family.holeNumOptionsTotal(hole)

def hole_set_options(self, hole, options):
assert len(options)>0
self.family.holeSetOptions(hole,options)
assert self.family.holeNumOptions(hole) == len(options)

@property
def size(self):
Expand All @@ -63,7 +61,7 @@ def size_or_order(self):

def hole_options_to_string(self, hole, options):
name = self.hole_name(hole)
labels = [self.hole_to_option_labels[hole][option] for option in options]
labels = [str(self.hole_to_option_labels[hole][option]) for option in options]
if len(labels) == 1:
return f"{name}={labels[0]}"
else:
Expand Down Expand Up @@ -159,6 +157,7 @@ def __init__(self, family = None, parent_info = None):
super().__init__(family)

self.mdp = None
self.selected_choices = None

# SMT encoding
self.encoding = None
Expand All @@ -179,11 +178,11 @@ def copy(self):

def collect_parent_info(self, specification):
pi = ParentInfo()
pi.selected_choices = self.selected_choices
pi.refinement_depth = self.refinement_depth
cr = self.analysis_result.constraints_result
pi.constraint_indices = cr.undecided_constraints if cr is not None else []
pi.splitter = self.splitter
pi.mdp = self.mdp
return pi

def encode(self, smt_solver):
Expand Down
40 changes: 40 additions & 0 deletions paynt/models/model_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import stormpy

class ModelBuilder:

@classmethod
def default_builder_options(cls, specification = None):
# builder options
if specification is not None:
formulae = specification.stormpy_formulae()
builder_options = stormpy.BuilderOptions(formulae)
else:
builder_options = stormpy.BuilderOptions()
builder_options.set_build_state_valuations(True)
builder_options.set_build_with_choice_origins(True)
builder_options.set_build_all_labels(True)
builder_options.set_build_choice_labels(True)
builder_options.set_add_overlapping_guards_label(True)
builder_options.set_build_observation_valuations(True)
# builder_options.set_exploration_checks(True)
return builder_options

@classmethod
def from_jani(cls, program, specification = None):
builder_options = cls.default_builder_options(specification)
builder_options.set_build_choice_labels(False)
model = stormpy.build_sparse_model_with_options(program, builder_options)
return model

@classmethod
def from_prism(cls, program, specification = None):
assert program.model_type in [stormpy.storage.PrismModelType.MDP, stormpy.storage.PrismModelType.POMDP]
builder_options = cls.default_builder_options(specification)
model = stormpy.build_sparse_model_with_options(program, builder_options)
return model

@classmethod
def from_drn(cls, drn_path):
builder_options = stormpy.core.DirectEncodingParserOptions()
builder_options.build_choice_labels = True
return stormpy.build_model_from_drn(drn_path, builder_options)
27 changes: 1 addition & 26 deletions paynt/quotient/models.py → paynt/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,41 +8,16 @@

class Mdp:

# options for the construction of chains
builder_options = None

@classmethod
def initialize(cls, specification):
# builder options
formulae = specification.stormpy_formulae()
cls.builder_options = stormpy.BuilderOptions(formulae)
cls.builder_options.set_build_with_choice_origins(True)
cls.builder_options.set_build_state_valuations(True)
cls.builder_options.set_add_overlapping_guards_label()
cls.builder_options.set_build_observation_valuations(True)
cls.builder_options.set_build_all_labels(True)
# cls.builder_options.set_exploration_checks(True)

@classmethod
def assert_no_overlapping_guards(cls, model):
if model.labeling.contains_label("overlap_guards"):
assert model.labeling.get_states("overlap_guards").number_of_set_bits() == 0

@classmethod
def from_prism(cls, prism):
assert prism.model_type in [stormpy.storage.PrismModelType.MDP, stormpy.storage.PrismModelType.POMDP]
# TODO why do we disable choice labels here?
Mdp.builder_options.set_build_choice_labels(True)
model = stormpy.build_sparse_model_with_options(prism, Mdp.builder_options)
Mdp.builder_options.set_build_choice_labels(False)
# Mdp.assert_no_overlapping_guards(model)
return model

def __init__(self, model):
# Mdp.assert_no_overlapping_guards(model)
self.model = model
if len(model.initial_states) > 1:
logger.warning("WARNING: obtained model with multiple initial states")
self.model = model

@property
def states(self):
Expand Down
7 changes: 3 additions & 4 deletions paynt/parser/jani.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import payntbind

import paynt.verification.property
import paynt.quotient.models
import paynt.models.model_builder

import itertools
from collections import defaultdict
Expand Down Expand Up @@ -56,7 +56,6 @@ def __init__(self, prism, hole_expressions, specification, family):
p = paynt.verification.property.OptimalityProperty(prop_new,epsilon)
properties_unpacked.append(p)
self.specification = paynt.verification.property.Specification(properties_unpacked)
paynt.quotient.models.Mdp.initialize(self.specification)

# unfold holes in the program
self.hole_expressions = hole_expressions
Expand All @@ -66,8 +65,8 @@ def __init__(self, prism, hole_expressions, specification, family):
logger.debug("constructing the quotient...")

# construct the explicit quotient
quotient_mdp = stormpy.build_sparse_model_with_options(self.jani_unfolded, paynt.quotient.models.Mdp.builder_options)

quotient_mdp = paynt.models.model_builder.ModelBuilder.from_jani(self.jani_unfolded, self.specification)
# associate each action of a quotient MDP with hole options
# reconstruct choice labels from choice origins
logger.debug("associating choices of the quotient with hole assignments...")
Expand Down
6 changes: 2 additions & 4 deletions paynt/parser/prism_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import paynt.family.family
import paynt.verification.property
import paynt.parser.jani
import paynt.quotient.models
import paynt.models.model_builder

import os
import re
Expand Down Expand Up @@ -53,13 +53,11 @@ def read_prism(cls, sketch_path, properties_path, relative_error):
specification = jani_unfolder.specification
quotient_mdp = jani_unfolder.quotient_mdp
coloring = payntbind.synthesis.Coloring(family.family, quotient_mdp.nondeterministic_choice_indices, jani_unfolder.choice_to_hole_options)
paynt.quotient.models.Mdp.initialize(specification)
if prism.model_type == stormpy.storage.PrismModelType.POMDP:
obs_evaluator = payntbind.synthesis.ObservationEvaluator(prism, quotient_mdp)
quotient_mdp = payntbind.synthesis.addChoiceLabelsFromJani(quotient_mdp)
else:
paynt.quotient.models.Mdp.initialize(specification)
quotient_mdp = paynt.quotient.models.Mdp.from_prism(prism)
quotient_mdp = paynt.models.model_builder.ModelBuilder.from_prism(prism, specification)

return prism, quotient_mdp, specification, family, coloring, jani_unfolder, obs_evaluator

Expand Down
17 changes: 6 additions & 11 deletions paynt/parser/sketch.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import stormpy
import payntbind


import paynt.quotient.models
import paynt.models.model_builder

import paynt.quotient.quotient
import paynt.quotient.mdp
import paynt.quotient.pomdp
import paynt.quotient.decpomdp
import paynt.quotient.posg
Expand Down Expand Up @@ -54,12 +54,6 @@ def make_rewards_action_based(model):

class Sketch:

@classmethod
def read_drn(cls, sketch_path):
builder_options = stormpy.core.DirectEncodingParserOptions()
builder_options.build_choice_labels = True
return stormpy.build_model_from_drn(sketch_path, builder_options)

@classmethod
def load_sketch(cls, sketch_path, properties_path,
export=None, relative_error=0, precision=1e-4, constraint_bound=None):
Expand Down Expand Up @@ -91,7 +85,7 @@ def load_sketch(cls, sketch_path, properties_path,
if filetype is None:
try:
logger.info(f"assuming sketch in DRN format...")
explicit_quotient = Sketch.read_drn(sketch_path)
explicit_quotient = paynt.models.model_builder.ModelBuilder.from_drn(sketch_path)
specification = PrismParser.parse_specification(properties_path, relative_error)
filetype = "drn"
except:
Expand Down Expand Up @@ -122,7 +116,6 @@ def load_sketch(cls, sketch_path, properties_path,
assert filetype is not None, "unknow format of input file"
logger.info("sketch parsing OK")

paynt.quotient.models.Mdp.initialize(specification)
paynt.verification.property.Property.initialize()

make_rewards_action_based(explicit_quotient)
Expand Down Expand Up @@ -223,11 +216,13 @@ def build_quotient_container(cls, prism, jani_unfolder, explicit_quotient, famil
elif prism.model_type == stormpy.storage.PrismModelType.POMDP:
quotient_container = paynt.quotient.pomdp_family.PomdpFamilyQuotient(explicit_quotient, family, coloring, specification, obs_evaluator)
else:
assert explicit_quotient.is_nondeterministic_model
assert explicit_quotient.is_nondeterministic_model, "expected nondeterministic model"
if decpomdp_manager is not None and decpomdp_manager.num_agents > 1:
quotient_container = paynt.quotient.decpomdp.DecPomdpQuotient(decpomdp_manager, specification)
elif explicit_quotient.labeling.contains_label(paynt.quotient.posg.PosgQuotient.PLAYER_1_STATE_LABEL):
quotient_container = paynt.quotient.posg.PosgQuotient(explicit_quotient, specification)
elif not explicit_quotient.is_partially_observable:
quotient_container = paynt.quotient.mdp.MdpQuotient(explicit_quotient, specification)
else:
quotient_container = paynt.quotient.pomdp.PomdpQuotient(explicit_quotient, specification, decpomdp_manager)
return quotient_container
Expand Down
Loading

0 comments on commit a2ea997

Please sign in to comment.