From 03040cad0fdf14abe550a0e7a3eb2df7da3fad59 Mon Sep 17 00:00:00 2001 From: Roman Andriushchenko Date: Thu, 26 Sep 2024 15:36:24 +0200 Subject: [PATCH] DT synthesis: tree postprocessing and export --- paynt/cli.py | 12 +- paynt/parser/jani.py | 4 +- paynt/quotient/mdp.py | 146 ++++++++++++++---- paynt/synthesizer/decision_tree.py | 72 ++++++--- paynt/synthesizer/synthesizer.py | 3 + paynt/synthesizer/synthesizer_ar.py | 2 +- .../src/synthesis/quotient/ColoringSmt.cpp | 2 +- 7 files changed, 177 insertions(+), 64 deletions(-) diff --git a/paynt/cli.py b/paynt/cli.py index f3dc989f..9711cda1 100644 --- a/paynt/cli.py +++ b/paynt/cli.py @@ -110,8 +110,8 @@ def setup_logger(log_path = None): help="path to output file for SAYNT belief FSC") @click.option("--export-fsc-paynt", type=click.Path(), default=None, help="path to output file for SAYNT inductive FSC") -@click.option("--export-evaluation", type=click.Path(), default=None, - help="base filename to output evaluation result") +@click.option("--export-synthesis", type=click.Path(), default=None, + help="base filename to output synthesis result") @click.option("--mdp-split-wrt-mdp", is_flag=True, default=False, help="if set, MDP abstraction scheduler will be used for splitting, otherwise game abstraction scheduler will be used") @@ -126,8 +126,6 @@ def setup_logger(log_path = None): help="decision tree synthesis: tree depth") @click.option("--tree-enumeration", is_flag=True, default=False, help="decision tree synthesis: if set, all trees of size at most tree_depth will be enumerated") -@click.option("--add-dont-care-action", is_flag=True, default=False, - help="decision tree synthesis: if set, an explicit action simulating a random action selection will be added to each state") @click.option( "--constraint-bound", type=click.FLOAT, help="bound for creating constrained POMDP for Cassandra models", @@ -148,9 +146,9 @@ def paynt_run( fsc_synthesis, fsc_memory_size, posterior_aware, storm_pomdp, iterative_storm, get_storm_result, storm_options, prune_storm, use_storm_cutoffs, unfold_strategy_storm, - export_fsc_storm, export_fsc_paynt, export_evaluation, + export_fsc_storm, export_fsc_paynt, export_synthesis, mdp_split_wrt_mdp, mdp_discard_unreachable_choices, mdp_use_randomized_abstraction, - tree_depth, tree_enumeration, add_dont_care_action, + tree_depth, tree_enumeration, constraint_bound, ce_generator, profiling @@ -166,6 +164,7 @@ def paynt_run( # set CLI parameters paynt.quotient.quotient.Quotient.disable_expected_visits = disable_expected_visits + paynt.synthesizer.synthesizer.Synthesizer.export_synthesis_filename_base = export_synthesis paynt.synthesizer.synthesizer_cegis.SynthesizerCEGIS.conflict_generator_type = ce_generator paynt.quotient.pomdp.PomdpQuotient.initial_memory_size = fsc_memory_size paynt.quotient.pomdp.PomdpQuotient.posterior_aware = posterior_aware @@ -177,7 +176,6 @@ def paynt_run( paynt.synthesizer.decision_tree.SynthesizerDecisionTree.tree_depth = tree_depth paynt.synthesizer.decision_tree.SynthesizerDecisionTree.tree_enumeration = tree_enumeration - paynt.quotient.mdp.MdpQuotient.add_dont_care_action = add_dont_care_action storm_control = None if storm_pomdp: diff --git a/paynt/parser/jani.py b/paynt/parser/jani.py index eb52b69a..baafadd4 100644 --- a/paynt/parser/jani.py +++ b/paynt/parser/jani.py @@ -219,8 +219,8 @@ def construct_edge(self, edge, substitution = None): for templ_edge_dest in edge.template_edge.destinations: assignments = templ_edge_dest.assignments.clone() if substitution is not None: - # assignments.substitute(substitution, substitute_transcendental_numbers=True) - assignments.substitute(substitution) # legacy version + assignments.substitute(substitution, substitute_transcendental_numbers=True) + # assignments.substitute(substitution) # legacy version templ_edge.add_destination(stormpy.storage.JaniTemplateEdgeDestination(assignments)) new_edge = stormpy.storage.JaniEdge( diff --git a/paynt/quotient/mdp.py b/paynt/quotient/mdp.py index fab80111..13de30bf 100644 --- a/paynt/quotient/mdp.py +++ b/paynt/quotient/mdp.py @@ -3,6 +3,7 @@ import stormpy import payntbind import json +import graphviz import logging logger = logging.getLogger(__name__) @@ -48,20 +49,6 @@ def __str__(self): domain = f"[{self.domain_min}..{self.domain_max}]" return f"{self.name}:{domain}" - @classmethod - def from_model(cls, model): - assert model.has_state_valuations(), "model has no state valuations" - sv = model.state_valuations - valuation = json.loads(str(sv.get_json(0))) - variable_name = [var_name for var_name in valuation] - state_valuations = [] - for state in range(model.nr_states): - valuation = json.loads(str(sv.get_json(state))) - valuation = [valuation[var_name] for var_name in variable_name] - state_valuations.append(valuation) - variables = [Variable(var,var_name,state_valuations) for var,var_name in enumerate(variable_name)] - variables = [v for v in variables if len(v.domain) > 1] - return variables @@ -69,12 +56,14 @@ class DecisionTreeNode: def __init__(self, parent): self.parent = parent - self.variable_index = None self.child_true = None self.child_false = None self.identifier = None self.holes = None - self.hole_assignment = None + + self.action = None + self.variable = None + self.variable_bound = None @property def is_terminal(self): @@ -93,6 +82,11 @@ def add_children(self): self.child_true = DecisionTreeNode(self) self.child_false = DecisionTreeNode(self) + def get_depth(self): + if self.is_terminal: + return 0 + return 1 + max([child.get_depth() for child in self.child_nodes]) + def assign_identifiers(self, identifier=0): self.identifier = identifier if self.is_terminal: @@ -109,26 +103,97 @@ def associate_holes(self, node_hole_info): self.child_false.associate_holes(node_hole_info) def associate_assignment(self, assignment): - self.hole_assignment = [assignment.hole_options(hole)[0] for hole in self.holes] + hole_assignment = [assignment.hole_options(hole)[0] for hole in self.holes] if self.is_terminal: + self.action = hole_assignment[0] return + + self.variable = hole_assignment[0] + self.variable_bound = hole_assignment[self.variable+1] + self.child_true.associate_assignment(assignment) self.child_false.associate_assignment(assignment) def apply_hint(self, subfamily, tree_hint): if self.is_terminal or tree_hint.is_terminal: return - for hole_index,option in enumerate(tree_hint.hole_assignment): - hole = self.holes[hole_index] - subfamily.hole_set_options(hole,[option]) + + variable_hint = tree_hint.variable + subfamily.hole_set_options(self.holes[0],[variable_hint]) + subfamily.hole_set_options(self.holes[variable_hint+1],[tree_hint.variable_bound]) self.child_true.apply_hint(subfamily,tree_hint.child_true) self.child_false.apply_hint(subfamily,tree_hint.child_false) + def simplify(self, variables, state_valuations): + if self.is_terminal: + return + + bound = variables[self.variable].domain[self.variable_bound] + state_valuations_true = [valuation for valuation in state_valuations if valuation[self.variable] <= bound] + state_valuations_false = [valuation for valuation in state_valuations if valuation[self.variable] > bound] + child_skip = None + if len(state_valuations_true) == 0: + child_skip = self.child_false + elif len(state_valuations_false) == 0: + child_skip = self.child_true + if child_skip is not None: + self.variable = child_skip.variable + self.variable_bound = child_skip.variable_bound + self.action = child_skip.action + self.child_true = child_skip.child_true + self.child_false = child_skip.child_false + self.simplify(variables,state_valuations) + return + + self.child_true.simplify(variables, state_valuations_true) + self.child_false.simplify(variables, state_valuations_false) + if not self.is_terminal and self.child_true.is_terminal and self.child_false.is_terminal and self.child_true.action == self.child_false.action: + self.variable = self.variable_bound = None + self.action = self.child_true.action + self.child_true = self.child_false = None + + def to_string(self, variables, action_labels, indent_level=0, indent_size=2): + indent = " "*indent_level*indent_size + if self.is_terminal: + return indent + f"{action_labels[self.action]}" + "\n" + var = variables[self.variable] + s = "" + s += indent + f"if {var.name}<={var.domain[self.variable_bound]}:" + "\n" + s += self.child_true.to_string(variables,action_labels,indent_level+1) + s += indent + f"else:" + "\n" + s += self.child_false.to_string(variables,action_labels,indent_level+1) + return s + + @property + def graphviz_id(self): + return str(self.identifier) + + def to_graphviz(self, graphviz_tree, variables, action_labels): + if not self.is_terminal: + for child in self.child_nodes: + child.to_graphviz(graphviz_tree,variables,action_labels) + + if self.is_terminal: + node_label = action_labels[self.action] + else: + var = variables[self.variable] + node_label = f"{var.name}<={var.domain[self.variable_bound]}" + + graphviz_tree.node(self.graphviz_id, label=node_label, shape="box", style="rounded", margin="0.05,0.05") + if not self.is_terminal: + graphviz_tree.edge(self.graphviz_id,self.child_true.graphviz_id,label="True") + graphviz_tree.edge(self.graphviz_id,self.child_false.graphviz_id,label="False") + + class DecisionTree: - def __init__(self, model): - self.variables = Variable.from_model(model) + def __init__(self, quotient, variable_name, state_valuations): + self.quotient = quotient + self.state_valuations = state_valuations + variables = [Variable(var,var_name,state_valuations) for var,var_name in enumerate(variable_name)] + variables = [v for v in variables if len(v.domain) > 1] + self.variables = variables logger.debug(f"found the following {len(self.variables)} variables: {[str(v) for v in self.variables]}") self.reset() @@ -142,6 +207,9 @@ def set_depth(self, depth:int): node.add_children() self.root.assign_identifiers() + def get_depth(self): + return self.root.get_depth() + def collect_nodes(self, node_condition=None): if node_condition is None: node_condition = lambda node : True @@ -170,25 +238,45 @@ def to_list(self): node_info[node.identifier] = (parent,child_true,child_false) return node_info + def simplify(self): + self.root.simplify(self.variables, self.state_valuations) + + def to_string(self): + return self.root.to_string(self.variables,self.quotient.action_labels) + + def to_graphviz(self): + logging.getLogger("graphviz").setLevel(logging.WARNING) + logging.getLogger("graphviz.sources").setLevel(logging.ERROR) + graphviz_tree = graphviz.Digraph(comment="decision tree") + self.root.to_graphviz(graphviz_tree,self.variables,self.quotient.action_labels) + return graphviz_tree -class MdpQuotient(paynt.quotient.quotient.Quotient): - # if set, an explicit action simulating a random action selection will be added to each state - add_dont_care_action = False +class MdpQuotient(paynt.quotient.quotient.Quotient): def __init__(self, mdp, specification): super().__init__(specification=specification) updated = payntbind.synthesis.restoreActionsInAbsorbingStates(mdp) if updated is not None: mdp = updated - # action_labels, _ payntbind.synthesis.extractActionLabels(mdp) - if MdpQuotient.add_dont_care_action: + action_labels,_ = payntbind.synthesis.extractActionLabels(mdp) + if "__random__" not in action_labels: + logger.debug("adding explicit don't-care action to every state...") mdp = payntbind.synthesis.addDontCareAction(mdp) - # stormpy.export_to_drn(mdp, sketch_path+".drn") self.quotient_mdp = mdp self.choice_destinations = payntbind.synthesis.computeChoiceDestinations(mdp) self.action_labels,self.choice_to_action = payntbind.synthesis.extractActionLabels(mdp) - self.decision_tree = DecisionTree(mdp) + + assert mdp.has_state_valuations(), "model has no state valuations" + sv = mdp.state_valuations + valuation = json.loads(str(sv.get_json(0))) + variable_name = [var_name for var_name in valuation] + state_valuations = [] + for state in range(mdp.nr_states): + valuation = json.loads(str(sv.get_json(state))) + valuation = [valuation[var_name] for var_name in variable_name] + state_valuations.append(valuation) + self.decision_tree = DecisionTree(self,variable_name,state_valuations) self.coloring = None self.family = None diff --git a/paynt/synthesizer/decision_tree.py b/paynt/synthesizer/decision_tree.py index 2d8400a2..55fee6df 100644 --- a/paynt/synthesizer/decision_tree.py +++ b/paynt/synthesizer/decision_tree.py @@ -19,6 +19,8 @@ class SynthesizerDecisionTree(paynt.synthesizer.synthesizer_ar.SynthesizerAR): def __init__(self, *args): super().__init__(*args) + self.best_tree = None + self.best_tree_value = None @property def method_name(self): @@ -102,22 +104,37 @@ def counters_print(self): logger.info(f"harmonizations succeeded: {self.num_harmonization_succeeded}") print() + def export_decision_tree(self, decision_tree, export_filename_base): + tree = decision_tree.to_graphviz() + # tree_filename = export_filename_base + ".dot" + # with open(tree_filename, 'w') as file: + # file.write(tree.source) + # logger.info(f"exported decision tree to {tree_filename}") + + tree_visualization_filename = export_filename_base + ".png" + tree.render(export_filename_base, format="png", cleanup=True) # using export_filename_base since graphviz appends .png by default + logger.info(f"exported decision tree visualization to {tree_visualization_filename}") + + def synthesize_tree(self, depth:int): self.counters_reset() self.quotient.set_depth(depth) + self.best_assignment = self.best_assignment_value = None self.synthesize(keep_optimum=True) + if self.best_assignment is not None: + self.quotient.decision_tree.root.associate_assignment(self.best_assignment) + self.best_tree = self.quotient.decision_tree + self.best_tree_value = self.best_assignment_value + self.best_assignment = self.best_assignment_value = None self.counters_print() def synthesize_tree_sequence(self, opt_result_value): - tree_hint = None + self.best_tree = self.best_tree_value = None + global_timeout = paynt.utils.timer.GlobalTimer.global_timer.time_limit_seconds - if global_timeout is None: - global_timeout = 300 + if global_timeout is None: global_timeout = 1800 depth_timeout = global_timeout / 2 / SynthesizerDecisionTree.tree_depth for depth in range(SynthesizerDecisionTree.tree_depth+1): - print() - print("DEPTH = ", depth) - print() self.quotient.set_depth(depth) best_assignment_old = self.best_assignment @@ -131,9 +148,9 @@ def synthesize_tree_sequence(self, opt_result_value): self.synthesis_timer.start() families = [family] - if SynthesizerDecisionTree.use_tree_hint and tree_hint is not None: + if SynthesizerDecisionTree.use_tree_hint and self.best_tree is not None: subfamily = family.copy() - self.quotient.decision_tree.root.apply_hint(subfamily,tree_hint) + self.quotient.decision_tree.root.apply_hint(subfamily,self.best_tree) families = [subfamily,family] for family in families: @@ -156,8 +173,9 @@ def synthesize_tree_sequence(self, opt_result_value): if abs( (self.best_assignment_value-opt_result_value)/opt_result_value ) < 1e-3: break - tree_hint = self.quotient.decision_tree.root - tree_hint.associate_assignment(self.best_assignment) + self.best_tree = self.quotient.decision_tree.root + self.best_tree.associate_assignment(self.best_assignment) + self.best_tree_value = self.best_assignment_value if self.resource_limit_reached(): break @@ -175,29 +193,35 @@ def run(self, optimum_threshold=None): if self.quotient.specification.optimality.maximizing == mc_result_positive: epsilon *= -1 # optimum_threshold = opt_result_value * (1 + epsilon) - self.set_optimality_threshold(optimum_threshold) - self.best_assignment = None - self.best_assignment_value = None + self.best_tree = None + self.best_tree_value = None if not SynthesizerDecisionTree.tree_enumeration: self.synthesize_tree(SynthesizerDecisionTree.tree_depth) else: self.synthesize_tree_sequence(opt_result_value) logger.info(f"the optimal scheduler has value: {opt_result_value}") - if self.best_assignment is not None: - logger.info(f"admissible assignment found: {self.best_assignment}") - if self.quotient.specification.has_optimality: - logger.info(f"best assignment has value {self.quotient.specification.optimality.optimum}") + if self.best_tree is None: + logger.info("no admissible tree found") else: - logger.info("no admissible assignment found") + self.best_tree.simplify() + logger.info(f"printing synthesized tree below:") + print(self.best_tree.to_string()) + + depth = self.best_tree.get_depth() + if self.quotient.specification.has_optimality: + logger.info(f"synthesized tree has value {self.best_tree_value}") + num_nodes = len(self.best_tree.collect_nonterminals()) + logger.info(f"synthesized tree is of depth {depth} and has {num_nodes} decision nodes") + if self.export_synthesis_filename_base is not None: + self.export_decision_tree(self.best_tree, self.export_synthesis_filename_base) time_total = paynt.utils.timer.GlobalTimer.read() - # logger.info(f"synthesis time: {round(time_total, 2)} s") - print() - for name,time in self.quotient.coloring.getProfilingInfo(): - time_percent = round(time/time_total*100,1) - print(f"{name} = {time} s ({time_percent} %)") + # print() + # for name,time in self.quotient.coloring.getProfilingInfo(): + # time_percent = round(time/time_total*100,1) + # print(f"{name} = {time} s ({time_percent} %)") - return self.best_assignment + return self.best_tree diff --git a/paynt/synthesizer/synthesizer.py b/paynt/synthesizer/synthesizer.py index 9f11e74f..6e9a4681 100644 --- a/paynt/synthesizer/synthesizer.py +++ b/paynt/synthesizer/synthesizer.py @@ -16,6 +16,9 @@ def __init__(self, family, value, sat, policy): class Synthesizer: + # base filename (i.e. without extension) to export synthesis result + export_synthesis_filename_base = None + @staticmethod def choose_synthesizer(quotient, method, fsc_synthesis=False, storm_control=None): diff --git a/paynt/synthesizer/synthesizer_ar.py b/paynt/synthesizer/synthesizer_ar.py index f09a412d..1af7788c 100644 --- a/paynt/synthesizer/synthesizer_ar.py +++ b/paynt/synthesizer/synthesizer_ar.py @@ -100,7 +100,7 @@ def update_optimum(self, family): self.quotient.specification.optimality.update_optimum(iv) self.best_assignment = ia self.best_assignment_value = iv - logger.info(f"value {round(iv,4)} achieved after {round(paynt.utils.timer.GlobalTimer.read(),2)} seconds") + # logger.info(f"value {round(iv,4)} achieved after {round(paynt.utils.timer.GlobalTimer.read(),2)} seconds") if isinstance(self.quotient, paynt.quotient.pomdp.PomdpQuotient): self.stat.new_fsc_found(family.analysis_result.improving_value, ia, self.quotient.policy_size(ia)) diff --git a/payntbind/src/synthesis/quotient/ColoringSmt.cpp b/payntbind/src/synthesis/quotient/ColoringSmt.cpp index 2bbc9fd4..29bb0441 100644 --- a/payntbind/src/synthesis/quotient/ColoringSmt.cpp +++ b/payntbind/src/synthesis/quotient/ColoringSmt.cpp @@ -104,7 +104,7 @@ ColoringSmt::ColoringSmt( break; } } - STORM_LOG_THROW(domain_option_found, storm::exceptions::UnexpectedException, "Hole option not found."); + STORM_LOG_THROW(domain_option_found, storm::exceptions::UnexpectedException, "hole option not found."); } }