diff --git a/src/translate/fact_groups.py b/src/translate/fact_groups.py index daec4a587..884704722 100644 --- a/src/translate/fact_groups.py +++ b/src/translate/fact_groups.py @@ -68,7 +68,11 @@ def _update_top(self): self.groups_by_size[len(candidate)].append(candidate) self.max_size -= 1 -def choose_groups(groups, reachable_facts): +def choose_groups(groups, reachable_facts, negative_in_goal): + if negative_in_goal: + # we remove atoms that occur negatively in the goal from the groups to + # enforce them to be encoded with a binary variable. + groups = [set(group) - negative_in_goal for group in groups] queue = GroupCoverQueue(groups) uncovered_facts = reachable_facts.copy() result = [] @@ -107,7 +111,8 @@ def sort_groups(groups): return sorted(sorted(group) for group in groups) def compute_groups(task: pddl.Task, atoms: Set[pddl.Literal], - reachable_action_params: Dict[pddl.Action, List[str]]) -> Tuple[ + reachable_action_params: Dict[pddl.Action, List[str]], + negative_in_goal: Set[pddl.Atom]) -> Tuple[ List[List[pddl.Atom]], # groups # -> all selected mutex groups plus singleton groups for uncovered facts List[List[pddl.Atom]], # mutex_groups @@ -128,7 +133,7 @@ def compute_groups(task: pddl.Task, atoms: Set[pddl.Literal], with timers.timing("Collecting mutex groups"): mutex_groups = collect_all_mutex_groups(groups, atoms) with timers.timing("Choosing groups", block=True): - groups = choose_groups(groups, atoms) + groups = choose_groups(groups, atoms, negative_in_goal) groups = sort_groups(groups) with timers.timing("Building translation key"): translation_key = build_translation_key(groups) diff --git a/src/translate/translate.py b/src/translate/translate.py index d2652f0fe..e060e8a59 100755 --- a/src/translate/translate.py +++ b/src/translate/translate.py @@ -555,12 +555,15 @@ def pddl_to_sas(task): elif goal_list is None: return unsolvable_sas_task("Trivially false goal") + negative_in_goal = set() for item in goal_list: assert isinstance(item, pddl.Literal) + if item.negated: + negative_in_goal.add(item.negate()) with timers.timing("Computing fact groups", block=True): groups, mutex_groups, translation_key = fact_groups.compute_groups( - task, atoms, reachable_action_params) + task, atoms, reachable_action_params, negative_in_goal) with timers.timing("Building STRIPS to SAS dictionary"): ranges, strips_to_sas = strips_to_sas_dictionary(