diff --git a/paynt/quotient/pomdp.py b/paynt/quotient/pomdp.py
index f4895ae1..b929089e 100644
--- a/paynt/quotient/pomdp.py
+++ b/paynt/quotient/pomdp.py
@@ -8,6 +8,7 @@
 import math
 import re
+import collections
 
 import logging
 logger = logging.getLogger(__name__)
 
@@ -100,8 +101,7 @@ def __init__(self, pomdp, specification, decpomdp_manager=None):
             obs = state_obs[state]
             if self.action_labels_at_observation[obs] != []:
                 continue
-            actions = self.pomdp.get_nr_available_actions(state)
-            for offset in range(actions):
+            for offset in range(self.actions_at_observation[obs]):
                 choice = self.pomdp.get_choice_index(state,offset)
                 labels = self.pomdp.choice_labeling.get_labels_of_choice(choice)
                 assert len(labels) <= 1, "expected at most 1 label"
@@ -780,3 +780,25 @@ def compute_qvalues(self, assignment):
                 state_memory_value_total[state][memory] = value
 
         return state_memory_value_total
+
+
+    def next_belief(self, belief, action_label, next_obs):
+        ''' Perform the Bayesian belief update: take the action with the given
+        label, receive observation next_obs, return the normalized new belief. '''
+        # all states in the support of a belief share the same observation
+        any_belief_state = next(iter(belief))
+        obs = self.pomdp.observations[any_belief_state]
+        action = self.action_labels_at_observation[obs].index(action_label)
+        new_belief = collections.defaultdict(float)
+        for state,state_prob in belief.items():
+            choice = self.pomdp.get_choice_index(state,action)
+            for entry in self.pomdp.transition_matrix.get_row(choice):
+                next_state = entry.column
+                # keep only successors consistent with the new observation
+                if self.pomdp.observations[next_state] == next_obs:
+                    new_belief[next_state] += state_prob * entry.value()
+        # normalize over the retained successors
+        prob_sum = sum(new_belief.values())
+        assert prob_sum > 0, "next_obs is not reachable from the given belief"
+        new_belief = {state:prob/prob_sum for state,prob in new_belief.items()}
+        return new_belief
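
For reference, next_belief() implements the standard POMDP belief update: taking the action with label a in belief b and then observing z gives each successor state s' whose observation equals z the mass sum_s b(s) * P(s' | s, a), renormalized over the retained successors. Below is a minimal usage sketch; the quotient object, the initial point belief, and the action/observation trace are illustrative assumptions, not part of this patch.

    # hypothetical driver for the new method; assumes `quotient` is a
    # constructed quotient from paynt/quotient/pomdp.py wrapping a stormpy POMDP
    initial_state = quotient.pomdp.initial_states[0]
    belief = {initial_state: 1.0}    # point belief in the initial state

    # replay a made-up action/observation trace, updating the belief stepwise
    trace = [("north", 3), ("north", 5), ("east", 5)]
    for action_label, next_obs in trace:
        belief = quotient.next_belief(belief, action_label, next_obs)

    # the result is a distribution over the states compatible with the trace
    assert abs(sum(belief.values()) - 1.0) < 1e-9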