Skip to content

Commit

Permalink
optimize POMDP coloring creation
Browse files: browse the repository at this point in the history
  • Loading branch information
Roman Andriushchenko committed Mar 13, 2024
1 parent d6152ac commit 8efef06
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions paynt/quotient/pomdp.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -217,11 +217,7 @@ def set_memory_from_result_new(self, obs_memory_dict, obs_memory_dict_cutoff, me


def create_coloring(self):

# short aliases
pm = self.pomdp_manager
pomdp = self.pomdp
mdp = self.quotient_mdp
logger.debug("creating coloring ...")

# create holes
all_holes = paynt.family.family.Family()
Expand All @@ -235,7 +231,7 @@ def create_coloring(self):
hole_indices = []
num_actions = self.actions_at_observation[obs]
if num_actions > 1:
option_labels = [str(labels) for labels in self.action_labels_at_observation[obs]]
option_labels = self.action_labels_at_observation[obs]
for mem in range(self.observation_memory_size[obs]):
hole_indices.append(all_holes.num_holes)
name = self.create_hole_name(obs,mem,True)
Expand All @@ -245,7 +241,7 @@ def create_coloring(self):

# memory holes
hole_indices = []
num_updates = pm.max_successor_memory_size[obs]
num_updates = self.pomdp_manager.max_successor_memory_size[obs]
if num_updates > 1:
option_labels = [str(x) for x in range(num_updates)]
for mem in range(self.observation_memory_size[obs]):
Expand All @@ -256,15 +252,20 @@ def create_coloring(self):
self.observation_memory_holes.append(hole_indices)

# create the coloring
assert self.pomdp_manager.num_holes == all_holes.num_holes
row_action_hole = self.pomdp_manager.row_action_hole
row_memory_hole = self.pomdp_manager.row_memory_hole
row_action_option = self.pomdp_manager.row_action_option
row_memory_option = self.pomdp_manager.row_memory_option
choice_to_hole_options = []
for action in range(mdp.nr_choices):
for action in range(self.quotient_mdp.nr_choices):
hole_options = []
h = pm.row_action_hole[action]
if h != pm.num_holes:
hole_options.append( (h,pm.row_action_option[action]) )
h = pm.row_memory_hole[action]
if h != pm.num_holes:
hole_options.append( (h,pm.row_memory_option[action]) )
hole = row_action_hole[action]
if hole != all_holes.num_holes:
hole_options.append( (hole,row_action_option[action]) )
hole = row_memory_hole[action]
if hole != all_holes.num_holes:
hole_options.append( (hole,row_memory_option[action]) )
choice_to_hole_options.append(hole_options)

return all_holes, choice_to_hole_options
Expand Down

0 comments on commit 8efef06

Please sign in to comment.