Skip to content

Commit

Permalink
DT synthesis: map scheduler to decision tree
Browse files Browse the repository at this point in the history
  • Loading branch information
Roman Andriushchenko committed Sep 27, 2024
1 parent 0ac2b44 commit 5762814
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 46 deletions.
8 changes: 7 additions & 1 deletion paynt/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,10 @@ def setup_logger(log_path = None):
help="decision tree synthesis: tree depth")
@click.option("--tree-enumeration", is_flag=True, default=False,
help="decision tree synthesis: if set, all trees of size at most tree_depth will be enumerated")
@click.option("--tree-map-scheduler", type=click.Path(), default=None,
help="decision tree synthesis: path to a scheduler to be mapped to a decision tree")
@click.option("--add-dont-care-action", is_flag=True, default=False,
help="decision tree synthesis: # if set, an explicit action executing a random choice of an available action will be added to each state")

@click.option(
"--constraint-bound", type=click.FLOAT, help="bound for creating constrained POMDP for Cassandra models",
Expand All @@ -148,7 +152,7 @@ def paynt_run(
use_storm_cutoffs, unfold_strategy_storm,
export_fsc_storm, export_fsc_paynt, export_synthesis,
mdp_split_wrt_mdp, mdp_discard_unreachable_choices, mdp_use_randomized_abstraction,
tree_depth, tree_enumeration,
tree_depth, tree_enumeration, tree_map_scheduler, add_dont_care_action,
constraint_bound,
ce_generator,
profiling
Expand Down Expand Up @@ -176,6 +180,8 @@ def paynt_run(

paynt.synthesizer.decision_tree.SynthesizerDecisionTree.tree_depth = tree_depth
paynt.synthesizer.decision_tree.SynthesizerDecisionTree.tree_enumeration = tree_enumeration
paynt.synthesizer.decision_tree.SynthesizerDecisionTree.scheduler_path = tree_map_scheduler
paynt.quotient.mdp.MdpQuotient.add_dont_care_action = add_dont_care_action

storm_control = None
if storm_pomdp:
Expand Down
56 changes: 47 additions & 9 deletions paynt/quotient/mdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,6 @@ def __init__(self, quotient, variables, state_valuations):
self.quotient = quotient
self.state_valuations = state_valuations
self.variables = variables
logger.debug(f"found the following {len(self.variables)} variables: {[str(v) for v in self.variables]}")
self.reset()

def reset(self):
Expand Down Expand Up @@ -252,18 +251,22 @@ def to_graphviz(self):

class MdpQuotient(paynt.quotient.quotient.Quotient):

# if true, an explicit action executing a random choice of an available action will be added to each state
add_dont_care_action = False

def __init__(self, mdp, specification):
super().__init__(specification=specification)
updated = payntbind.synthesis.restoreActionsInAbsorbingStates(mdp)
if updated is not None: mdp = updated
action_labels,_ = payntbind.synthesis.extractActionLabels(mdp)
if "__random__" not in action_labels:
if "__random__" not in action_labels and MdpQuotient.add_dont_care_action:
logger.debug("adding explicit don't-care action to every state...")
mdp = payntbind.synthesis.addDontCareAction(mdp)

self.quotient_mdp = mdp
self.choice_destinations = payntbind.synthesis.computeChoiceDestinations(mdp)
self.action_labels,self.choice_to_action = payntbind.synthesis.extractActionLabels(mdp)
logger.info(f"MDP has {len(self.action_labels)} actions")

assert mdp.has_state_valuations(), "model has no state valuations"
sv = mdp.state_valuations
Expand All @@ -274,23 +277,58 @@ def __init__(self, mdp, specification):
valuation = json.loads(str(sv.get_json(state)))
valuation = [valuation[var_name] for var_name in variable_name]
state_valuations.append(valuation)
self.state_valuations = state_valuations
variables = [Variable(var,var_name,state_valuations) for var,var_name in enumerate(variable_name)]
self.variables = [v for v in variables if len(v.domain) > 1]
variable_mask = [len(v.domain) > 1 for v in variables]
variables = [v for index,v in enumerate(variables) if variable_mask[index]]
for state,valuation in enumerate(state_valuations):
state_valuations[state] = [value for index,value in enumerate(valuation) if variable_mask[index]]
self.variables = variables
self.state_valuations = state_valuations
logger.debug(f"found the following {len(self.variables)} variables: {[str(v) for v in self.variables]}")

self.decision_tree = None
self.coloring = None
self.family = None
self.splitter_count = None

def decide(self, node, var_name):
node.set_variable_by_name(var_name,self.decision_tree)
def state_valuation_to_state(self, valuation):
valuation = [valuation[v.name] for v in self.variables]
for state,state_valuation in enumerate(self.state_valuations):
if valuation == state_valuation:
return state
else:
assert False, "state valuation not found"

def scheduler_json_to_choices(self, scheduler_json):
ndi = self.quotient_mdp.nondeterministic_choice_indices.copy()
assert self.quotient_mdp.nr_states == len(scheduler_json)
state_to_choice = self.empty_scheduler()
for state_decision in scheduler_json:
state = self.state_valuation_to_state(state_decision["s"])
actions = state_decision["c"]
assert len(actions) == 1
action_labels = actions[0]["labels"]
assert len(action_labels) <= 1
if len(action_labels) == 0:
state_to_choice[state] = ndi[state]
continue
action = self.action_labels.index(action_labels[0])
# find a choice that executes this action
for choice in range(ndi[state],ndi[state+1]):
if self.choice_to_action[choice] == action:
state_to_choice[state] = choice
break
else:
assert False, "action is not available in the state"
state_to_choice = self.discard_unreachable_choices(state_to_choice)
choices = self.state_to_choice_to_choices(state_to_choice)
return choices

def reset_tree(self, depth):
def reset_tree(self, depth, disable_counterexamples=False):
'''
Rebuild the decision tree template, the design space and the coloring.
'''
logger.debug(f"synthesizing tree of depth {depth}")
logger.debug(f"building tree of depth {depth}")
self.decision_tree = DecisionTree(self,self.variables,self.state_valuations)
self.decision_tree.set_depth(depth)

Expand All @@ -299,7 +337,7 @@ def reset_tree(self, depth):
variable_name = [v.name for v in variables]
variable_domain = [v.domain for v in variables]
tree_list = self.decision_tree.to_list()
self.coloring = payntbind.synthesis.ColoringSmt(self.quotient_mdp, variable_name, variable_domain, tree_list, False)
self.coloring = payntbind.synthesis.ColoringSmt(self.quotient_mdp, variable_name, variable_domain, tree_list, disable_counterexamples)

# reconstruct the family
hole_info = self.coloring.getFamilyInfo()
Expand Down
8 changes: 4 additions & 4 deletions paynt/quotient/quotient.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,12 @@ def discard_unreachable_choices(self, state_to_choice):
state_queue.append(dst)
return state_to_choice_reachable

def scheduler_to_state_to_choice(self, mdp, scheduler, discard_unreachable_choices=True):
state_to_quotient_choice = payntbind.synthesis.schedulerToStateToGlobalChoice(scheduler, mdp.model, mdp.quotient_choice_map)
def scheduler_to_state_to_choice(self, submdp, scheduler, discard_unreachable_choices=True):
state_to_quotient_choice = payntbind.synthesis.schedulerToStateToGlobalChoice(scheduler, submdp.model, submdp.quotient_choice_map)
state_to_choice = self.empty_scheduler()
for state in range(mdp.model.nr_states):
for state in range(submdp.model.nr_states):
quotient_choice = state_to_quotient_choice[state]
quotient_state = mdp.quotient_state_map[state]
quotient_state = submdp.quotient_state_map[state]
state_to_choice[quotient_state] = quotient_choice
if discard_unreachable_choices:
state_to_choice = self.discard_unreachable_choices(state_to_choice)
Expand Down
90 changes: 69 additions & 21 deletions paynt/synthesizer/decision_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import stormpy
import payntbind

import json

import logging
logger = logging.getLogger(__name__)

Expand All @@ -14,8 +16,8 @@ class SynthesizerDecisionTree(paynt.synthesizer.synthesizer_ar.SynthesizerAR):
tree_depth = 0
# if set, all trees of size at most tree_depth will be enumerated
tree_enumeration = False
# if set, the optimal k-tree will be used to jumpstart the synthesis of the (k+1)-tree
use_tree_hint = True
# path to a scheduler to be mapped to a decision tree
scheduler_path = None

def __init__(self, *args):
super().__init__(*args)
Expand Down Expand Up @@ -149,7 +151,7 @@ def synthesize_tree_sequence(self, opt_result_value):
self.synthesis_timer.start()
families = [family]

if SynthesizerDecisionTree.use_tree_hint and self.best_tree is not None:
if self.best_tree is not None:
subfamily = family.copy()
self.quotient.decision_tree.root.apply_hint(subfamily,self.best_tree.root)
families = [subfamily,family]
Expand Down Expand Up @@ -181,41 +183,87 @@ def synthesize_tree_sequence(self, opt_result_value):
if self.resource_limit_reached():
break

def map_scheduler(self, scheduler_choices, opt_result_value):
# use counterexamples iff a dont' care action exists
disable_counterexamples = "__random__" not in self.quotient.action_labels
self.counters_reset()
for depth in range(SynthesizerDecisionTree.tree_depth+1):
self.quotient.reset_tree(depth,disable_counterexamples=disable_counterexamples)
family = self.quotient.family
self.quotient.build(family)
family.analysis_result = self.quotient.build_unsat_result()
best_assignment_old = self.best_assignment

consistent,hole_selection = self.quotient.are_choices_consistent(scheduler_choices, family)
if consistent:
self.verify_hole_selection(family,hole_selection)
elif not disable_counterexamples:
harmonizing_hole = [hole for hole,options in enumerate(hole_selection) if len(options)>1][0]
selection_1 = hole_selection.copy(); selection_1[harmonizing_hole] = [selection_1[harmonizing_hole][0]]
selection_2 = hole_selection.copy(); selection_2[harmonizing_hole] = [selection_2[harmonizing_hole][1]]
for selection in [selection_1,selection_2]:
self.verify_hole_selection(family,selection)

new_assignment_synthesized = self.best_assignment != best_assignment_old
if new_assignment_synthesized:
self.best_tree = self.quotient.decision_tree
self.best_tree.root.associate_assignment(self.best_assignment)
self.best_tree_value = self.best_assignment_value
if abs( (self.best_assignment_value-opt_result_value)/opt_result_value ) < 1e-4:
break

if self.resource_limit_reached():
break

# self.counters_print()

def run(self, optimum_threshold=None):
paynt_mdp = paynt.models.models.Mdp(self.quotient.quotient_mdp)
mc_result = paynt_mdp.model_check_property(self.quotient.get_property())

scheduler_choices = None
if SynthesizerDecisionTree.scheduler_path is None:
paynt_mdp = paynt.models.models.Mdp(self.quotient.quotient_mdp)
mc_result = paynt_mdp.model_check_property(self.quotient.get_property())
else:
opt_result_value = None
with open(SynthesizerDecisionTree.scheduler_path, 'r') as f:
scheduler_json = json.load(f)
scheduler_choices = self.quotient.scheduler_json_to_choices(scheduler_json)
submdp = self.quotient.build_from_choice_mask(scheduler_choices)
mc_result = submdp.model_check_property(self.quotient.get_property())
opt_result_value = mc_result.value
logger.info(f"the optimal scheduler has value: {opt_result_value}")

if self.quotient.specification.has_optimality:
epsilon = 1e-1
mc_result_positive = opt_result_value > 0
if self.quotient.specification.optimality.maximizing == mc_result_positive:
epsilon *= -1
# optimum_threshold = opt_result_value * (1 + epsilon)
self.set_optimality_threshold(optimum_threshold)

self.best_tree = None
self.best_tree_value = None
if not SynthesizerDecisionTree.tree_enumeration:
self.synthesize_tree(SynthesizerDecisionTree.tree_depth)
self.best_assignment = self.best_assignment_value = None
self.best_tree = self.best_tree_value = None
if scheduler_choices is not None:
self.map_scheduler(scheduler_choices, opt_result_value)
else:
self.synthesize_tree_sequence(opt_result_value)
if self.quotient.specification.has_optimality:
epsilon = 1e-1
mc_result_positive = opt_result_value > 0
if self.quotient.specification.optimality.maximizing == mc_result_positive:
epsilon *= -1
# optimum_threshold = opt_result_value * (1 + epsilon)
self.set_optimality_threshold(optimum_threshold)

if not SynthesizerDecisionTree.tree_enumeration:
self.synthesize_tree(SynthesizerDecisionTree.tree_depth)
else:
self.synthesize_tree_sequence(opt_result_value)

logger.info(f"the optimal scheduler has value: {opt_result_value}")
if self.best_tree is None:
logger.info("no admissible tree found")
else:
self.best_tree.simplify()
logger.info(f"printing synthesized tree below:")
logger.info(f"printing the synthesized tree below:")
print(self.best_tree.to_string())

depth = self.best_tree.get_depth()
if self.quotient.specification.has_optimality:
logger.info(f"synthesized tree has value {self.best_tree_value}")
logger.info(f"the synthesized tree has value {self.best_tree_value}")
num_nodes = len(self.best_tree.collect_nonterminals())
logger.info(f"synthesized tree is of depth {depth} and has {num_nodes} decision nodes")
logger.info(f"the synthesized tree is of depth {depth} and has {num_nodes} decision nodes")
if self.export_synthesis_filename_base is not None:
self.export_decision_tree(self.best_tree, self.export_synthesis_filename_base)
time_total = paynt.utils.timer.GlobalTimer.read()
Expand Down
13 changes: 4 additions & 9 deletions payntbind/src/synthesis/quotient/ColoringSmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ ColoringSmt<ValueType>::ColoringSmt(
std::vector<std::string> const& variable_name,
std::vector<std::vector<int64_t>> const& variable_domain,
std::vector<std::tuple<uint64_t,uint64_t,uint64_t>> const& tree_list,
bool one_consistency_check
bool disable_counterexamples
) : initial_state(*model.getInitialStates().begin()),
row_groups(model.getNondeterministicChoiceIndices()),
choice_destinations(synthesis::computeChoiceDestinations(model)),
choice_to_action(synthesis::extractActionLabels(model).second),
variable_name(variable_name), variable_domain(variable_domain),
solver(ctx), harmonizing_variable(ctx), one_consistency_check(one_consistency_check) {
solver(ctx), harmonizing_variable(ctx), disable_counterexamples(disable_counterexamples) {

timers[__FUNCTION__].start();

Expand Down Expand Up @@ -165,7 +165,7 @@ ColoringSmt<ValueType>::ColoringSmt(
}
timers["ColoringSmt:: create choice colors"].stop();

if(one_consistency_check) {
if(disable_counterexamples) {
timers[__FUNCTION__].stop();
return;
}
Expand Down Expand Up @@ -438,7 +438,7 @@ std::pair<bool,std::vector<std::vector<uint64_t>>> ColoringSmt<ValueType>::areCh
return std::make_pair(true,hole_options_vector);
}

if(one_consistency_check) {
if(disable_counterexamples) {
solver.pop();
solver.pop();
timers[__FUNCTION__].stop();
Expand All @@ -454,7 +454,6 @@ std::pair<bool,std::vector<std::vector<uint64_t>>> ColoringSmt<ValueType>::areCh
BitVector state_reached(numStates(),false);
state_reached.set(initial_state,true);
consistent = true;
uint64_t num_choices_added = 0;
while(consistent) {
STORM_LOG_THROW(not unexplored_states.empty(), storm::exceptions::UnexpectedException, "all states explored but UNSAT core not found");
uint64_t state = unexplored_states.front(); unexplored_states.pop();
Expand All @@ -466,7 +465,6 @@ std::pair<bool,std::vector<std::vector<uint64_t>>> ColoringSmt<ValueType>::areCh
const char *label = choice_path_label[choice][path].c_str();
solver.add(choice_path_expresssion[choice][path], label);
}
// std::cout << "(2) adding choice " << (++num_choices_added) << std::endl;
consistent = check();
if(not consistent) {
break;
Expand Down Expand Up @@ -539,9 +537,6 @@ std::pair<bool,std::vector<std::vector<uint64_t>>> ColoringSmt<ValueType>::areCh
}





template class ColoringSmt<>;

}
4 changes: 2 additions & 2 deletions payntbind/src/synthesis/quotient/ColoringSmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class ColoringSmt {
std::vector<std::string> const& variable_name,
std::vector<std::vector<int64_t>> const& variable_domain,
std::vector<std::tuple<uint64_t,uint64_t,uint64_t>> const& tree_list,
bool one_consistency_check = false
bool disable_counterexamples = false
);

/** For each hole, get a list of name-type pairs. */
Expand Down Expand Up @@ -134,7 +134,7 @@ class ColoringSmt {
void loadUnsatCore(z3::expr_vector const& unsat_core_expr, Family const& subfamily);

/** If true, the object will be setup for one consistency check. */
bool one_consistency_check;
bool disable_counterexamples;

};

Expand Down

0 comments on commit 5762814

Please sign in to comment.