diff --git a/paynt/quotient/pomdp.py b/paynt/quotient/pomdp.py index f3a4c0e6a..a5541693b 100644 --- a/paynt/quotient/pomdp.py +++ b/paynt/quotient/pomdp.py @@ -217,11 +217,7 @@ def set_memory_from_result_new(self, obs_memory_dict, obs_memory_dict_cutoff, me def create_coloring(self): - - # short aliases - pm = self.pomdp_manager - pomdp = self.pomdp - mdp = self.quotient_mdp + logger.debug("creating coloring ...") # create holes all_holes = paynt.family.family.Family() @@ -235,7 +231,7 @@ def create_coloring(self): hole_indices = [] num_actions = self.actions_at_observation[obs] if num_actions > 1: - option_labels = [str(labels) for labels in self.action_labels_at_observation[obs]] + option_labels = self.action_labels_at_observation[obs] for mem in range(self.observation_memory_size[obs]): hole_indices.append(all_holes.num_holes) name = self.create_hole_name(obs,mem,True) @@ -245,7 +241,7 @@ def create_coloring(self): # memory holes hole_indices = [] - num_updates = pm.max_successor_memory_size[obs] + num_updates = self.pomdp_manager.max_successor_memory_size[obs] if num_updates > 1: option_labels = [str(x) for x in range(num_updates)] for mem in range(self.observation_memory_size[obs]): @@ -256,15 +252,20 @@ def create_coloring(self): self.observation_memory_holes.append(hole_indices) # create the coloring + assert self.pomdp_manager.num_holes == all_holes.num_holes + row_action_hole = self.pomdp_manager.row_action_hole + row_memory_hole = self.pomdp_manager.row_memory_hole + row_action_option = self.pomdp_manager.row_action_option + row_memory_option = self.pomdp_manager.row_memory_option choice_to_hole_options = [] - for action in range(mdp.nr_choices): + for action in range(self.quotient_mdp.nr_choices): hole_options = [] - h = pm.row_action_hole[action] - if h != pm.num_holes: - hole_options.append( (h,pm.row_action_option[action]) ) - h = pm.row_memory_hole[action] - if h != pm.num_holes: - hole_options.append( (h,pm.row_memory_option[action]) ) + hole = row_action_hole[action] + if hole != all_holes.num_holes: + hole_options.append( (hole,row_action_option[action]) ) + hole = row_memory_hole[action] + if hole != all_holes.num_holes: + hole_options.append( (hole,row_memory_option[action]) ) choice_to_hole_options.append(hole_options) return all_holes, choice_to_hole_options