diff --git a/paynt/quotient/fsc.py b/paynt/quotient/fsc.py index 3452ebce..40d21a55 100644 --- a/paynt/quotient/fsc.py +++ b/paynt/quotient/fsc.py @@ -60,6 +60,37 @@ def from_json(cls, json): fsc.update_function = json["update_function"] return fsc + def reorder_nodes(self, node_old_to_new): + action_function = [None for node in range(self.num_nodes)] + update_function = [None for node in range(self.num_nodes)] + for node_old,node_new in enumerate(node_old_to_new): + action_function[node_new] = self.action_function[node_old] + update_function[node_new] = [node_old_to_new[node] for node in self.update_function[node_old]] + self.action_function = action_function + self.update_function = update_function + + def reorder_actions(self, action_labels): + for node in range(self.num_nodes): + for obs in range(self.num_observations): + if self.is_deterministic: + action = self.action_function[node][obs] + self.action_function[node][obs] = action_labels.index(self.action_labels[action]) + else: + action_function = {} + for action,prob in self.action_function[node][obs].items(): + action_function[action_labels.index(self.action_labels[action])] = prob + self.action_function[node][obs] = action_function + self.action_labels = action_labels.copy() + + def make_stochastic(self): + if not self.is_deterministic: + return + for node in range(self.num_nodes): + for obs in range(self.num_observations): + self.action_function[node][obs] = {self.action_function[node][obs] : 1.0} + self.update_function[node][obs] = {self.update_function[node][obs] : 1.0} + self.is_deterministic = False + def check_action_function(self, observation_to_actions): assert len(self.action_function) == self.num_nodes, "FSC action function is not defined for all memory nodes" for node in range(self.num_nodes): @@ -114,11 +145,12 @@ def fill_zero_updates(self): for node in range(self.num_nodes): self.update_function[node] = [0] * self.num_observations - # this fixes FSCs contructed from not fully unfolded quotients - # this can only be used when the current state of the FSC is correct def fill_implicit_actions_and_updates(self): + ''' + For an FSC with an irregular memory model, make action and updates explicit. + ''' for node in range(self.num_nodes): for obs in range(self.num_observations): - if self.action_function[node][obs] == None: + if self.action_function[node][obs] is None: self.action_function[node][obs] = self.action_function[0][obs] self.update_function[node][obs] = self.update_function[0][obs] diff --git a/paynt/quotient/pomdp.py b/paynt/quotient/pomdp.py index c725fe23..d086ca3c 100644 --- a/paynt/quotient/pomdp.py +++ b/paynt/quotient/pomdp.py @@ -719,21 +719,18 @@ def assignment_to_fsc(self, assignment): # convert hole assignment to FSC for obs,holes in enumerate(self.observation_action_holes): - for memory,hole in enumerate(holes): + for node,hole in enumerate(holes): option = assignment.hole_options(hole)[0] action_label = self.action_labels_at_observation[obs][option] action = action_label_indices[action_label] - fsc.action_function[memory][obs] = action + fsc.action_function[node][obs] = action for obs,holes in enumerate(self.observation_memory_holes): - for memory,hole in enumerate(holes): + for node,hole in enumerate(holes): option = assignment.hole_options(hole)[0] - fsc.update_function[memory][obs] = option + fsc.update_function[node][obs] = option - # fixing the FSC for not fully unrolled quotients fsc.fill_implicit_actions_and_updates() - fsc.check(observation_to_actions) - return fsc diff --git a/payntbind/src/synthesis/translation/choiceTransformation.cpp b/payntbind/src/synthesis/translation/choiceTransformation.cpp index d7715cbd..ac844c52 100644 --- a/payntbind/src/synthesis/translation/choiceTransformation.cpp +++ b/payntbind/src/synthesis/translation/choiceTransformation.cpp @@ -201,7 +201,7 @@ std::pair>,std::vector translated_to_original_choice; std::vector translated_to_original_choice_label; std::vector row_groups_new; - storm::storage::BitVector action_exists(num_actions,false); + std::vector action_to_choice(num_actions); for(uint64_t state = 0; state < num_states; ++state) { row_groups_new.push_back(translated_to_original_choice.size()); if(not state_mask[state]) { @@ -214,23 +214,27 @@ std::pair>,std::vector