Skip to content

Commit

Permalink
DT synthesis: fix iterative algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
Roman Andriushchenko committed Sep 26, 2024
1 parent 1b30597 commit 0ac2b44
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 16 deletions.
18 changes: 10 additions & 8 deletions paynt/quotient/mdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,9 @@ def to_graphviz(self, graphviz_tree, variables, action_labels):

class DecisionTree:

def __init__(self, quotient, variable_name, state_valuations):
def __init__(self, quotient, variables, state_valuations):
self.quotient = quotient
self.state_valuations = state_valuations
variables = [Variable(var,var_name,state_valuations) for var,var_name in enumerate(variable_name)]
variables = [v for v in variables if len(v.domain) > 1]
self.variables = variables
logger.debug(f"found the following {len(self.variables)} variables: {[str(v) for v in self.variables]}")
self.reset()
Expand Down Expand Up @@ -276,20 +274,24 @@ def __init__(self, mdp, specification):
valuation = json.loads(str(sv.get_json(state)))
valuation = [valuation[var_name] for var_name in variable_name]
state_valuations.append(valuation)
self.decision_tree = DecisionTree(self,variable_name,state_valuations)
self.state_valuations = state_valuations
variables = [Variable(var,var_name,state_valuations) for var,var_name in enumerate(variable_name)]
self.variables = [v for v in variables if len(v.domain) > 1]

self.decision_tree = None
self.coloring = None
self.family = None
self.splitter_count = None

def decide(self, node, var_name):
node.set_variable_by_name(var_name,self.decision_tree)

'''
Build the design space and coloring corresponding to the current decision tree.
'''
def set_depth(self, depth):
def reset_tree(self, depth):
'''
Rebuild the decision tree template, the design space and the coloring.
'''
logger.debug(f"synthesizing tree of depth {depth}")
self.decision_tree = DecisionTree(self,self.variables,self.state_valuations)
self.decision_tree.set_depth(depth)

# logger.debug("building coloring...")
Expand Down
15 changes: 8 additions & 7 deletions paynt/synthesizer/decision_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def export_decision_tree(self, decision_tree, export_filename_base):

def synthesize_tree(self, depth:int):
self.counters_reset()
self.quotient.set_depth(depth)
self.quotient.reset_tree(depth)
self.best_assignment = self.best_assignment_value = None
self.synthesize(keep_optimum=True)
if self.best_assignment is not None:
Expand All @@ -135,7 +135,8 @@ def synthesize_tree_sequence(self, opt_result_value):
if global_timeout is None: global_timeout = 1800
depth_timeout = global_timeout / 2 / SynthesizerDecisionTree.tree_depth
for depth in range(SynthesizerDecisionTree.tree_depth+1):
self.quotient.set_depth(depth)
print()
self.quotient.reset_tree(depth)
best_assignment_old = self.best_assignment

family = self.quotient.family
Expand All @@ -150,7 +151,7 @@ def synthesize_tree_sequence(self, opt_result_value):

if SynthesizerDecisionTree.use_tree_hint and self.best_tree is not None:
subfamily = family.copy()
self.quotient.decision_tree.root.apply_hint(subfamily,self.best_tree)
self.quotient.decision_tree.root.apply_hint(subfamily,self.best_tree.root)
families = [subfamily,family]

for family in families:
Expand All @@ -170,13 +171,13 @@ def synthesize_tree_sequence(self, opt_result_value):
result = dtmc.check_specification(self.quotient.specification)
logger.info(f"double-checking specification satisfiability: {result}")

self.best_tree = self.quotient.decision_tree
self.best_tree.root.associate_assignment(self.best_assignment)
self.best_tree_value = self.best_assignment_value

if abs( (self.best_assignment_value-opt_result_value)/opt_result_value ) < 1e-3:
break

self.best_tree = self.quotient.decision_tree.root
self.best_tree.associate_assignment(self.best_assignment)
self.best_tree_value = self.best_assignment_value

if self.resource_limit_reached():
break

Expand Down
2 changes: 1 addition & 1 deletion paynt/synthesizer/synthesizer_ar.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def update_optimum(self, family):
self.quotient.specification.optimality.update_optimum(iv)
self.best_assignment = ia
self.best_assignment_value = iv
# logger.info(f"value {round(iv,4)} achieved after {round(paynt.utils.timer.GlobalTimer.read(),2)} seconds")
logger.info(f"value {round(iv,4)} achieved after {round(paynt.utils.timer.GlobalTimer.read(),2)} seconds")
if isinstance(self.quotient, paynt.quotient.pomdp.PomdpQuotient):
self.stat.new_fsc_found(family.analysis_result.improving_value, ia, self.quotient.policy_size(ia))

Expand Down

0 comments on commit 0ac2b44

Please sign in to comment.