diff --git a/paynt/quotient/fsc.py b/paynt/quotient/fsc.py new file mode 100644 index 000000000..0502328d4 --- /dev/null +++ b/paynt/quotient/fsc.py @@ -0,0 +1,113 @@ +import json + +import logging +logger = logging.getLogger(__name__) + + +class FSC: + ''' + Class for encoding an FSC having + - a fixed number of nodes + - action selection is either: + + deterministic: gamma: NxZ -> Act, or + + randomized: gamma: NxZ -> Distr(Act), where gamma(n,z) is a dictionary of pairs (action,probability) + - deterministic posterior-unaware memory update delta: NxZ -> N + ''' + + def __init__(self, num_nodes, num_observations, is_deterministic=False): + self.num_nodes = num_nodes + self.num_observations = num_observations + self.is_deterministic = is_deterministic + + self.action_function = [ [None]*num_observations for _ in range(num_nodes) ] + self.update_function = [ [None]*num_observations for _ in range(num_nodes) ] + + self.observation_labels = None + self.action_labels = None + + def __str__(self): + return json.dumps(self.to_json(), indent=4) + + def action_function_signature(self): + if self.is_deterministic: + return " NxZ -> Act" + else: + return " NxZ -> Distr(Act)" + + def to_json(self): + json = {} + json["num_nodes"] = self.num_nodes + json["num_observations"] = self.num_observations + json["__comment_action_function"] = "action function has signature {}".format(self.action_function_signature()) + json["__comment_update_function"] = "update function has signature NxZ -> N" + + json["action_function"] = self.action_function + json["update_function"] = self.update_function + + if self.action_labels is not None: + json["action_labels"] = self.action_labels + if self.observation_labels is not None: + json["observation_labels"] = self.observation_labels + + return json + + @classmethod + def from_json(cls, json): + num_nodes = json["num_nodes"] + num_observations = json["num_observations"] + fsc = FSC(num_nodes,num_observations) + fsc.action_function = json["action_function"] + fsc.update_function = json["update_function"] + return fsc + + def check_action_function(self, observation_to_actions): + assert len(self.action_function) == self.num_nodes, "FSC action function is not defined for all memory nodes" + for node in range(self.num_nodes): + assert len(self.action_function[node]) == self.num_observations, \ + "in memory node {}, FSC action function is not defined for all observations".format(node) + for obs in range(self.num_observations): + if self.is_deterministic: + action = self.action_function[node][obs] + assert action in observation_to_actions[obs], "in observation {} FSC chooses invalid action {}".format(obs,action) + else: + for action,_ in self.action_function[node][obs].items(): + assert action in observation_to_actions[obs], "in observation {} FSC chooses invalid action {}".format(obs,action) + + def check_update_function(self): + assert len(self.update_function) == self.num_nodes, "FSC update function is not defined for all memory nodes" + for node in range(self.num_nodes): + assert len(self.update_function[node]) == self.num_observations, \ + "in memory node {}, FSC update function is not defined for all observations".format(node) + for obs in range(self.num_observations): + update = self.update_function[node][obs] + assert 0 <= update and update < self.num_nodes, "invalid FSC memory update {}".format(update) + + def check(self, observation_to_actions): + ''' Check whether fields of FSC have been initialized appropriately. ''' + assert self.num_nodes > 0, "FSC must have at least 1 node" + self.check_action_function(observation_to_actions) + self.check_update_function() + + def fill_trivial_actions(self, observation_to_actions): + ''' For each observation with 1 available action, set gamma(n,z) to that action. ''' + for obs,actions in enumerate(observation_to_actions): + if len(actions)>1: + continue + action = actions[0] + if not self.is_deterministic: + action = {action:1} + for node in range(self.num_nodes): + self.action_function[node][obs] = action + + def fill_trivial_updates(self, observation_to_actions): + ''' For each observation with 1 available action, set delta(n,z) to n. ''' + for obs,actions in enumerate(observation_to_actions): + if len(actions)>1: + continue + for node in range(self.num_nodes): + self.update_function[node][obs] = node + + def fill_zero_updates(self): + for node in range(self.num_nodes): + self.update_function[node] = [0] * self.num_observations + diff --git a/paynt/quotient/pomdp.py b/paynt/quotient/pomdp.py index 75f4871cd..f3a4c0e6a 100644 --- a/paynt/quotient/pomdp.py +++ b/paynt/quotient/pomdp.py @@ -4,6 +4,7 @@ import paynt.family.family import paynt.quotient.quotient +import paynt.quotient.fsc from .models import MarkovChain,MDP,DTMC @@ -21,6 +22,9 @@ class PomdpQuotient(paynt.quotient.quotient.Quotient): # if True, posterior-aware unfolding will be applied posterior_aware = False + # label associated with un-labelled choices + EMPTY_LABEL = "__no_label__" + def __init__(self, pomdp, specification, decpomdp_manager=None): super().__init__(specification = specification) @@ -70,7 +74,6 @@ def __init__(self, pomdp, specification, decpomdp_manager=None): if self.pomdp.has_observation_valuations(): ov = self.pomdp.observation_valuations self.observation_labels = [ov.get_string(obs) for obs in range(self.observations)] - self.observation_labels = [self.simplify_label(label) for label in self.observation_labels] else: if decpomdp_manager is None: self.observation_labels = list(range(self.observations)) @@ -102,7 +105,12 @@ def __init__(self, pomdp, specification, decpomdp_manager=None): for offset in range(actions): choice = self.pomdp.get_choice_index(state,offset) labels = self.pomdp.choice_labeling.get_labels_of_choice(choice) - self.action_labels_at_observation[obs].append(labels) + assert len(labels) <= 1, "expected at most 1 label" + if len(labels) == 0: + label = PomdpQuotient.EMPTY_LABEL + else : + label = list(labels)[0] + self.action_labels_at_observation[obs].append(label) # mark perfect observations self.observation_states = [0 for obs in range(self.observations)] @@ -153,22 +161,6 @@ def decode_hole_name(self, name): break return (is_action_hole, observation, memory) - def simplify_label(self,label): - label = re.sub(r"\s+", "", label) - label = label[1:-1] - - output = "["; - first = True - for p in label.split("&"): - if not p.endswith("=0"): - if first: - first = False - else: - output += " & " - output += p - output += "]" - return output - def set_manager_memory_vector(self): for obs in range(self.observations): mem = self.observation_memory_size[obs] @@ -671,8 +663,10 @@ def policy_size(self, assignment): return size_gamma + size_delta - # constructs pomdp from the quotient MDP, used for computing POMDP abstraction bounds def get_family_pomdp(self, mdp): + ''' + Constructs POMDP from the quotient MDP. Used for computing POMDP abstraction bounds. + ''' no_obs = self.pomdp.nr_observations tm = mdp.model.transition_matrix components = stormpy.storage.SparseModelComponents(tm, mdp.model.labeling, mdp.model.reward_models) @@ -709,3 +703,42 @@ def get_family_pomdp(self, mdp): pomdp = stormpy.pomdp.make_canonic(pomdp) return pomdp + + + def assignment_to_fsc(self, assignment): + assert assignment.size == 1, "expected family of size 1" + num_nodes = max(self.observation_memory_size) + fsc = paynt.quotient.fsc.FSC(num_nodes, self.observations, is_deterministic=True) + fsc.observation_labels = self.observation_labels + + # collect action labels + action_labels = set() + for labels in self.action_labels_at_observation: + action_labels.update(labels) + action_labels = list(action_labels) + fsc.action_labels = action_labels + + # map observations to unique indices of available actions + action_label_indices = {label:index for index,label in enumerate(action_labels)} + observation_to_actions = [[] for obs in range(self.observations)] + for obs,action_labels in enumerate(self.action_labels_at_observation): + observation_to_actions[obs] = [action_label_indices[label] for label in action_labels] + + fsc.fill_trivial_actions(observation_to_actions) + fsc.fill_zero_updates() + + # convert hole assignment to FSC + for obs,holes in enumerate(self.observation_action_holes): + for memory,hole in enumerate(holes): + option = assignment.hole_options(hole)[0] + action_label = self.action_labels_at_observation[obs][option] + action = action_label_indices[action_label] + fsc.action_function[memory][obs] = action + for obs,holes in enumerate(self.observation_memory_holes): + for memory,hole in enumerate(holes): + option = assignment.hole_options(hole)[0] + fsc.update_function[memory][obs] = option + + fsc.check(observation_to_actions) + + return fsc diff --git a/paynt/quotient/pomdp_family.py b/paynt/quotient/pomdp_family.py index a3b5baaae..2007f2b6b 100644 --- a/paynt/quotient/pomdp_family.py +++ b/paynt/quotient/pomdp_family.py @@ -6,106 +6,10 @@ import paynt.quotient.mdp_family import paynt.verification.property -import json import logging logger = logging.getLogger(__name__) - -class FSC: - ''' - Class for encoding an FSC having - - a fixed number of nodes - - action selection is either: - + deterministic: gamma: NxZ -> Act, or - + randomized: gamma: NxZ -> Distr(Act), where gamma(n,z) is a dictionary of pairs (action,probability) - - deterministic posterior-unaware memory update delta: NxZ -> N - ''' - - def __init__(self, num_nodes, num_observations, is_deterministic=False): - self.num_nodes = num_nodes - self.num_observations = num_observations - self.is_deterministic = is_deterministic - - self.action_function = [ [None]*num_observations for _ in range(num_nodes) ] - self.update_function = [ [None]*num_observations for _ in range(num_nodes) ] - - def __str__(self): - return json.dumps(self.to_json(), indent=4) - - def action_function_signature(self): - if self.is_deterministic: - return " NxZ -> Act" - else: - return " NxZ -> Distr(Act)" - - def to_json(self): - json = {} - json["num_nodes"] = self.num_nodes - json["num_observations"] = self.num_observations - json["__comment_action_function"] = "action function has signature {}".format(self.action_function_signature()) - json["__comment_update_function"] = "update function has signature NxZ -> N" - json["action_function"] = self.action_function - json["update_function"] = self.update_function - return json - - @classmethod - def from_json(cls, json): - num_nodes = json["num_nodes"] - num_observations = json["num_observations"] - fsc = FSC(num_nodes,num_observations) - fsc.action_function = json["action_function"] - fsc.update_function = json["update_function"] - return fsc - - def check_action_function(self, observation_to_actions): - assert len(self.action_function) == self.num_nodes, "FSC action function is not defined for all memory nodes" - for node in range(self.num_nodes): - assert len(self.action_function[node]) == self.num_observations, \ - "in memory node {}, FSC action function is not defined for all observations".format(node) - for obs in range(self.num_observations): - if self.is_deterministic: - action = self.action_function[node][obs] - assert action in observation_to_actions[obs], "in observation {} FSC chooses invalid action {}".format(obs,action) - else: - for action,_ in self.action_function[node][obs].items(): - assert action in observation_to_actions[obs], "in observation {} FSC chooses invalid action {}".format(obs,action) - - def check_update_function(self): - assert len(self.update_function) == self.num_nodes, "FSC update function is not defined for all memory nodes" - for node in range(self.num_nodes): - assert len(self.update_function[node]) == self.num_observations, \ - "in memory node {}, FSC update function is not defined for all observations".format(node) - for obs in range(self.num_observations): - update = self.update_function[node][obs] - assert 0 <= update and update < self.num_nodes, "invalid FSC memory update {}".format(update) - - def check(self, observation_to_actions): - ''' Check whether fields of FSC have been initialized appropriately. ''' - assert self.num_nodes > 0, "FSC must have at least 1 node" - self.check_action_function(observation_to_actions) - self.check_update_function() - - def fill_trivial_actions(self, observation_to_actions): - ''' For each observation with 1 available action, set gamma(n,z) to that action. ''' - for obs,actions in enumerate(observation_to_actions): - if len(actions)>1: - continue - action = actions[0] - if not self.is_deterministic: - action = {action:1} - for node in range(self.num_nodes): - self.action_function[node][obs] = action - - def fill_trivial_updates(self, observation_to_actions): - ''' For each observation with 1 available action, set delta(n,z) to n. ''' - for obs,actions in enumerate(observation_to_actions): - if len(actions)>1: - continue - for node in range(self.num_nodes): - self.update_function[node][obs] = node - - class SubPomdp: ''' Simple container for a (sub-)POMDP created from the quotient. @@ -255,7 +159,7 @@ def compute_qvalues_for_product_submdp(self, product_submdp : paynt.quotient.mod return state_memory_action_to_value - def translate_path_to_trace(self, dtmc_sketch, dtmc, path): + def translate_path_to_trace(self, dtmc, path): invalid_choice = self.quotient_mdp.nr_choices trace = [] for dtmc_state in path: @@ -320,7 +224,7 @@ def compute_witnessing_traces(self, dtmc_sketch, satisfying_assignment, num_trac if not success: break path.append(simulator.get_current_state()) - trace = self.translate_path_to_trace(dtmc_sketch,dtmc,path) + trace = self.translate_path_to_trace(dtmc,path) traces.append(trace) else: # target is reachable: use KSP @@ -336,6 +240,6 @@ def compute_witnessing_traces(self, dtmc_sketch, satisfying_assignment, num_trac for k in range(1,num_traces+1): path = shortest_paths_generator.get_path_as_list(k) path.reverse() - trace = self.translate_path_to_trace(dtmc_sketch,dtmc,path) + trace = self.translate_path_to_trace(dtmc,path) traces.append(trace) return traces diff --git a/paynt/quotient/storm_pomdp_control.py b/paynt/quotient/storm_pomdp_control.py index 777370084..80b4ab84a 100644 --- a/paynt/quotient/storm_pomdp_control.py +++ b/paynt/quotient/storm_pomdp_control.py @@ -393,13 +393,13 @@ def parse_storm_result(self, quotient): for label in state.labels: # observation based on prism observables if '[' in label: - simplified_label = self.quotient.simplify_label(label) - observation = self.quotient.observation_labels.index(simplified_label) + observation = self.quotient.observation_labels.index(label) index = -1 + choice_label = list(get_choice_label(state.id))[0] for i in range(len(quotient.action_labels_at_observation[int(observation)])): - if list(get_choice_label(state.id))[0] in quotient.action_labels_at_observation[int(observation)][i]: + if choice_label == quotient.action_labels_at_observation[int(observation)][i]: index = i break @@ -413,9 +413,9 @@ def parse_storm_result(self, quotient): _, observation = label.split('_') index = -1 - + choice_label = list(get_choice_label(state.id))[0] for i in range(len(quotient.action_labels_at_observation[int(observation)])): - if list(get_choice_label(state.id))[0] in quotient.action_labels_at_observation[int(observation)][i]: + if choice_label == quotient.action_labels_at_observation[int(observation)][i]: index = i break