diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml
index 2aaced7e9..c286b0241 100644
--- a/.github/workflows/windows.yml
+++ b/.github/workflows/windows.yml
@@ -15,7 +15,7 @@ env:
   cplex_DIR: D:\a\cplex
   CPLEX_URL: "${{ secrets.CPLEX2211_WINDOWS_URL }}"
-  ZLIB_URL: "https://www.zlib.net/zlib13.zip"
+  ZLIB_URL: "https://www.zlib.net/zlib131.zip"

 jobs:
@@ -53,7 +53,7 @@ jobs:
         cd zlib
         echo "Set up zlib include directory"
-        move ../zlib-1.3 include
+        move ../zlib-1.3.1 include
         echo "Compile zlib library"
         cd include
diff --git a/README.md b/README.md
index 727384f47..80d8bc58c 100644
--- a/README.md
+++ b/README.md
@@ -313,6 +313,7 @@ contributing, and finally by last name.
 - 2022-2023 Remo Christen
 - 2023 Simon Dold
 - 2023 Claudia S. Grundke
+- 2023 Victor Paléologue
 - 2023 Emanuele Tirendi
 - 2021-2022 Dominik Drexler
 - 2016-2020 Cedric Geissmann
diff --git a/src/search/landmarks/landmark_factory.h b/src/search/landmarks/landmark_factory.h
index a0b78fbad..2258f6bc6 100644
--- a/src/search/landmarks/landmark_factory.h
+++ b/src/search/landmarks/landmark_factory.h
@@ -32,13 +32,6 @@ class LandmarkFactory {
     std::shared_ptr<LandmarkGraph> compute_lm_graph(const std::shared_ptr<AbstractTask> &task);

-    /*
-      TODO: Currently reasonable orders are not supported for admissible landmark count
-      heuristics, which is why the heuristic needs to know whether the factory computes
-      reasonable orders. Once issue383 is dealt with we should be able to use reasonable
-      orders for admissible heuristics and this method can be removed.
-    */
-    virtual bool computes_reasonable_orders() const = 0;
     virtual bool supports_conditional_effects() const = 0;

     bool achievers_are_calculated() const {
diff --git a/src/search/landmarks/landmark_factory_h_m.cc b/src/search/landmarks/landmark_factory_h_m.cc
index f6e92362f..2894b41a3 100644
--- a/src/search/landmarks/landmark_factory_h_m.cc
+++ b/src/search/landmarks/landmark_factory_h_m.cc
@@ -1008,10 +1008,6 @@ void LandmarkFactoryHM::generate_landmarks(
     postprocess(task_proxy);
 }

-bool LandmarkFactoryHM::computes_reasonable_orders() const {
-    return false;
-}
-
 bool LandmarkFactoryHM::supports_conditional_effects() const {
     return false;
 }
diff --git a/src/search/landmarks/landmark_factory_h_m.h b/src/search/landmarks/landmark_factory_h_m.h
index 7d6f0737f..d8a411faa 100644
--- a/src/search/landmarks/landmark_factory_h_m.h
+++ b/src/search/landmarks/landmark_factory_h_m.h
@@ -140,7 +140,6 @@ class LandmarkFactoryHM : public LandmarkFactory {
 public:
     explicit LandmarkFactoryHM(const plugins::Options &opts);

-    virtual bool computes_reasonable_orders() const override;
     virtual bool supports_conditional_effects() const override;
 };
 }
diff --git a/src/search/landmarks/landmark_factory_merged.cc b/src/search/landmarks/landmark_factory_merged.cc
index 61e6bb60a..f67fb9a3b 100644
--- a/src/search/landmarks/landmark_factory_merged.cc
+++ b/src/search/landmarks/landmark_factory_merged.cc
@@ -132,15 +132,6 @@ void LandmarkFactoryMerged::postprocess() {
     lm_graph->set_landmark_ids();
 }

-bool LandmarkFactoryMerged::computes_reasonable_orders() const {
-    for (const shared_ptr<LandmarkFactory> &lm_factory : lm_factories) {
-        if (lm_factory->computes_reasonable_orders()) {
-            return true;
-        }
-    }
-    return false;
-}
-
 bool LandmarkFactoryMerged::supports_conditional_effects() const {
     for (const shared_ptr<LandmarkFactory> &lm_factory : lm_factories) {
         if (!lm_factory->supports_conditional_effects()) {
diff --git a/src/search/landmarks/landmark_factory_merged.h b/src/search/landmarks/landmark_factory_merged.h
index 9d332f91c..06556d9e1 100644
--- a/src/search/landmarks/landmark_factory_merged.h
+++ b/src/search/landmarks/landmark_factory_merged.h
@@ -15,7 +15,6 @@ class LandmarkFactoryMerged : public LandmarkFactory {
 public:
     explicit LandmarkFactoryMerged(const plugins::Options &opts);

-    virtual bool computes_reasonable_orders() const override;
     virtual bool supports_conditional_effects() const override;
 };
 }
diff --git a/src/search/landmarks/landmark_factory_reasonable_orders_hps.cc b/src/search/landmarks/landmark_factory_reasonable_orders_hps.cc
index 9d6a16baf..3469c0c1a 100644
--- a/src/search/landmarks/landmark_factory_reasonable_orders_hps.cc
+++ b/src/search/landmarks/landmark_factory_reasonable_orders_hps.cc
@@ -349,10 +349,6 @@ bool LandmarkFactoryReasonableOrdersHPS::effect_always_happens(
     return eff.empty();
 }

-bool LandmarkFactoryReasonableOrdersHPS::computes_reasonable_orders() const {
-    return true;
-}
-
 bool LandmarkFactoryReasonableOrdersHPS::supports_conditional_effects() const {
     return lm_factory->supports_conditional_effects();
 }
diff --git a/src/search/landmarks/landmark_factory_reasonable_orders_hps.h b/src/search/landmarks/landmark_factory_reasonable_orders_hps.h
index 91a7f2f99..01758afd4 100644
--- a/src/search/landmarks/landmark_factory_reasonable_orders_hps.h
+++ b/src/search/landmarks/landmark_factory_reasonable_orders_hps.h
@@ -21,7 +21,6 @@ class LandmarkFactoryReasonableOrdersHPS : public LandmarkFactory {
 public:
     LandmarkFactoryReasonableOrdersHPS(const plugins::Options &opts);

-    virtual bool computes_reasonable_orders() const override;
     virtual bool supports_conditional_effects() const override;
 };
 }
diff --git a/src/search/landmarks/landmark_factory_rpg_exhaust.cc b/src/search/landmarks/landmark_factory_rpg_exhaust.cc
index a67a91d2f..51e5d7957 100644
--- a/src/search/landmarks/landmark_factory_rpg_exhaust.cc
+++ b/src/search/landmarks/landmark_factory_rpg_exhaust.cc
@@ -53,10 +53,6 @@ void LandmarkFactoryRpgExhaust::generate_relaxed_landmarks(
     }
 }

-bool LandmarkFactoryRpgExhaust::computes_reasonable_orders() const {
-    return false;
-}
-
 bool LandmarkFactoryRpgExhaust::supports_conditional_effects() const {
     return false;
 }
diff --git a/src/search/landmarks/landmark_factory_rpg_exhaust.h b/src/search/landmarks/landmark_factory_rpg_exhaust.h
index 027432ea1..710789356 100644
--- a/src/search/landmarks/landmark_factory_rpg_exhaust.h
+++ b/src/search/landmarks/landmark_factory_rpg_exhaust.h
@@ -12,7 +12,6 @@ class LandmarkFactoryRpgExhaust : public LandmarkFactoryRelaxation {
 public:
     explicit LandmarkFactoryRpgExhaust(const plugins::Options &opts);

-    virtual bool computes_reasonable_orders() const override;
     virtual bool supports_conditional_effects() const override;
 };
 }
diff --git a/src/search/landmarks/landmark_factory_rpg_sasp.cc b/src/search/landmarks/landmark_factory_rpg_sasp.cc
index af70a805b..97a5e6b7d 100644
--- a/src/search/landmarks/landmark_factory_rpg_sasp.cc
+++ b/src/search/landmarks/landmark_factory_rpg_sasp.cc
@@ -627,10 +627,6 @@ void LandmarkFactoryRpgSasp::discard_disjunctive_landmarks() {
     }
 }

-bool LandmarkFactoryRpgSasp::computes_reasonable_orders() const {
-    return false;
-}
-
 bool LandmarkFactoryRpgSasp::supports_conditional_effects() const {
     return true;
 }
diff --git a/src/search/landmarks/landmark_factory_rpg_sasp.h b/src/search/landmarks/landmark_factory_rpg_sasp.h
index 1489743dc..9922974fd 100644
--- a/src/search/landmarks/landmark_factory_rpg_sasp.h
+++ b/src/search/landmarks/landmark_factory_rpg_sasp.h
@@ -64,7 +64,6 @@ class LandmarkFactoryRpgSasp : public LandmarkFactoryRelaxation {
 public:
     explicit LandmarkFactoryRpgSasp(const plugins::Options &opts);

-    virtual bool computes_reasonable_orders() const override;
     virtual bool supports_conditional_effects() const override;
 };
 }
diff --git a/src/search/landmarks/landmark_factory_zhu_givan.cc b/src/search/landmarks/landmark_factory_zhu_givan.cc
index 111cd54e8..349fac895 100644
--- a/src/search/landmarks/landmark_factory_zhu_givan.cc
+++ b/src/search/landmarks/landmark_factory_zhu_givan.cc
@@ -300,10 +300,6 @@ void LandmarkFactoryZhuGivan::add_operator_to_triggers(const OperatorProxy &op)
     triggers[lm.var][lm.value].push_back(op_or_axiom_id);
 }

-bool LandmarkFactoryZhuGivan::computes_reasonable_orders() const {
-    return false;
-}
-
 bool LandmarkFactoryZhuGivan::supports_conditional_effects() const {
     return true;
 }
diff --git a/src/search/landmarks/landmark_factory_zhu_givan.h b/src/search/landmarks/landmark_factory_zhu_givan.h
index 1b49f2cf0..dbc1569a3 100644
--- a/src/search/landmarks/landmark_factory_zhu_givan.h
+++ b/src/search/landmarks/landmark_factory_zhu_givan.h
@@ -77,7 +77,6 @@ class LandmarkFactoryZhuGivan : public LandmarkFactoryRelaxation {
 public:
     explicit LandmarkFactoryZhuGivan(const plugins::Options &opts);

-    virtual bool computes_reasonable_orders() const override;
     virtual bool supports_conditional_effects() const override;
 };
 }
diff --git a/src/search/merge_and_shrink/merge_strategy_factory_sccs.cc b/src/search/merge_and_shrink/merge_strategy_factory_sccs.cc
index 74550965e..2d9dd9ef9 100644
--- a/src/search/merge_and_shrink/merge_strategy_factory_sccs.cc
+++ b/src/search/merge_and_shrink/merge_strategy_factory_sccs.cc
@@ -74,30 +74,23 @@ unique_ptr<MergeStrategy> MergeStrategyFactorySCCs::compute_merge_strategy(
         break;
     }

-    /*
-      Compute the indices at which the merged SCCs can be found when all
-      SCCs have been merged.
-    */
-    int index = num_vars - 1;
-    log << "SCCs of the causal graph:" << endl;
+    if (log.is_at_least_normal()) {
+        log << "SCCs of the causal graph:" << endl;
+    }
     vector<vector<int>> non_singleton_cg_sccs;
-    vector<int> indices_of_merged_sccs;
-    indices_of_merged_sccs.reserve(sccs.size());
     for (const vector<int> &scc : sccs) {
-        log << scc << endl;
+        if (log.is_at_least_normal()) {
+            log << scc << endl;
+        }
         int scc_size = scc.size();
-        if (scc_size == 1) {
-            indices_of_merged_sccs.push_back(scc.front());
-        } else {
-            index += scc_size - 1;
-            indices_of_merged_sccs.push_back(index);
+        if (scc_size != 1) {
             non_singleton_cg_sccs.push_back(scc);
         }
     }
-    if (sccs.size() == 1) {
+    if (log.is_at_least_normal() && sccs.size() == 1) {
         log << "Only one single SCC" << endl;
     }
-    if (static_cast<int>(sccs.size()) == num_vars) {
+    if (log.is_at_least_normal() && static_cast<int>(sccs.size()) == num_vars) {
         log << "Only singleton SCCs" << endl;
         assert(non_singleton_cg_sccs.empty());
     }
@@ -111,8 +104,7 @@ unique_ptr<MergeStrategy> MergeStrategyFactorySCCs::compute_merge_strategy(
         task_proxy,
         merge_tree_factory,
         merge_selector,
-        move(non_singleton_cg_sccs),
-        move(indices_of_merged_sccs));
+        move(non_singleton_cg_sccs));
 }

 bool MergeStrategyFactorySCCs::requires_init_distances() const {
diff --git a/src/search/merge_and_shrink/merge_strategy_sccs.cc b/src/search/merge_and_shrink/merge_strategy_sccs.cc
index 7cd4bffb8..a5a7ae2e2 100644
--- a/src/search/merge_and_shrink/merge_strategy_sccs.cc
+++ b/src/search/merge_and_shrink/merge_strategy_sccs.cc
@@ -18,14 +18,12 @@ MergeStrategySCCs::MergeStrategySCCs(
     const TaskProxy &task_proxy,
     const shared_ptr<MergeTreeFactory> &merge_tree_factory,
     const shared_ptr<MergeSelector> &merge_selector,
-    vector<vector<int>> non_singleton_cg_sccs,
-    vector<int> indices_of_merged_sccs)
+    vector<vector<int>> &&non_singleton_cg_sccs)
     : MergeStrategy(fts),
       task_proxy(task_proxy),
      merge_tree_factory(merge_tree_factory),
       merge_selector(merge_selector),
       non_singleton_cg_sccs(move(non_singleton_cg_sccs)),
-      indices_of_merged_sccs(move(indices_of_merged_sccs)),
       current_merge_tree(nullptr) {
 }

@@ -33,31 +31,40 @@ MergeStrategySCCs::~MergeStrategySCCs() {
 }

 pair<int, int> MergeStrategySCCs::get_next() {
-    // We did not already start merging an SCC/all finished SCCs, so we
-    // do not have a current set of indices we want to finish merging.
     if (current_ts_indices.empty()) {
-        // Get the next indices we need to merge
+        /*
+          We are currently not dealing with merging all factors of an SCC, so
+          we need to either get the next one or allow merging any existing
+          factors of the FTS if there is no SCC left.
+        */
         if (non_singleton_cg_sccs.empty()) {
-            assert(indices_of_merged_sccs.size() > 1);
-            current_ts_indices = move(indices_of_merged_sccs);
+            // We are done dealing with all SCCs, allow merging any factors.
+            current_ts_indices.reserve(fts.get_num_active_entries());
+            for (int ts_index : fts) {
+                current_ts_indices.push_back(ts_index);
+            }
         } else {
+            /*
+              There is another SCC we have to deal with. Store its factors so
+              that we merge them over the next iterations.
+            */
             vector<int> &current_scc = non_singleton_cg_sccs.front();
             assert(current_scc.size() > 1);
             current_ts_indices = move(current_scc);
             non_singleton_cg_sccs.erase(non_singleton_cg_sccs.begin());
         }

-        // If using a merge tree factory, compute a merge tree for this set
+        // If using a merge tree factory, compute a merge tree for this set.
         if (merge_tree_factory) {
             current_merge_tree = merge_tree_factory->compute_merge_tree(
                 task_proxy, fts, current_ts_indices);
         }
     } else {
-        // Add the most recent merge to the current indices set
+        // Add the most recent product to the current index set.
         current_ts_indices.push_back(fts.get_size() - 1);
     }

-    // Select the next merge for the current set of indices, either using the
+    // Select the next merge from the current index set, either using the
     // tree or the selector.
     pair<int, int> next_pair;
     int merged_ts_index = fts.get_size();
@@ -72,7 +79,7 @@ pair<int, int> MergeStrategySCCs::get_next() {
         next_pair = merge_selector->select_merge(fts, current_ts_indices);
     }

-    // Remove the two merged indices from the current set of indices.
+    // Remove the two merged indices from the current index set.
     for (vector<int>::iterator it = current_ts_indices.begin();
          it != current_ts_indices.end();) {
         if (*it == next_pair.first || *it == next_pair.second) {
diff --git a/src/search/merge_and_shrink/merge_strategy_sccs.h b/src/search/merge_and_shrink/merge_strategy_sccs.h
index d1c6abc4d..8fac9390e 100644
--- a/src/search/merge_and_shrink/merge_strategy_sccs.h
+++ b/src/search/merge_and_shrink/merge_strategy_sccs.h
@@ -17,9 +17,7 @@ class MergeStrategySCCs : public MergeStrategy {
     std::shared_ptr<MergeTreeFactory> merge_tree_factory;
     std::shared_ptr<MergeSelector> merge_selector;
     std::vector<std::vector<int>> non_singleton_cg_sccs;
-    std::vector<int> indices_of_merged_sccs;

-    // Active "merge strategies" while merging a set of indices
     std::unique_ptr<MergeTree> current_merge_tree;
     std::vector<int> current_ts_indices;
 public:
@@ -28,8 +26,7 @@ class MergeStrategySCCs : public MergeStrategy {
         const TaskProxy &task_proxy,
         const std::shared_ptr<MergeTreeFactory> &merge_tree_factory,
         const std::shared_ptr<MergeSelector> &merge_selector,
-        std::vector<std::vector<int>> non_singleton_cg_sccs,
-        std::vector<int> indices_of_merged_sccs);
+        std::vector<std::vector<int>> &&non_singleton_cg_sccs);
     virtual ~MergeStrategySCCs() override;
     virtual std::pair<int, int> get_next() override;
 };
diff --git a/src/search/planner.cc b/src/search/planner.cc
index 47113af6d..afe060290 100644
--- a/src/search/planner.cc
+++ b/src/search/planner.cc
@@ -49,6 +49,5 @@ int main(int argc, const char **argv) {
     } else if (search_algorithm->get_status() == UNSOLVABLE) {
         exitcode = ExitCode::SEARCH_UNSOLVABLE;
     }
-    utils::report_exit_code_reentrant(exitcode);
-    return static_cast<int>(exitcode);
+    exit_with(exitcode);
 }
diff --git a/src/search/utils/system.cc b/src/search/utils/system.cc
index f2f54b6f4..001f68542 100644
--- a/src/search/utils/system.cc
+++ b/src/search/utils/system.cc
@@ -49,11 +49,9 @@ void exit_with(ExitCode exitcode) {
     exit(static_cast<int>(exitcode));
 }

-void exit_after_receiving_signal(ExitCode exitcode) {
-    /*
-      In signal handlers, we have to use the "safe function" _Exit() rather
-      than the unsafe function exit().
-    */
+void exit_with_reentrant(ExitCode exitcode) {
+    /* In signal handlers or when we run out of memory, we have to use the
+       "safe function" _Exit() rather than the unsafe function exit(). */
     report_exit_code_reentrant(exitcode);
     _Exit(static_cast<int>(exitcode));
 }
diff --git a/src/search/utils/system.h b/src/search/utils/system.h
index a5ec692ae..1f23ef19c 100644
--- a/src/search/utils/system.h
+++ b/src/search/utils/system.h
@@ -54,7 +54,7 @@ enum class ExitCode {
 };

 NO_RETURN extern void exit_with(ExitCode returncode);
-NO_RETURN extern void exit_after_receiving_signal(ExitCode returncode);
+NO_RETURN extern void exit_with_reentrant(ExitCode returncode);

 int get_peak_memory_in_kb();
 const char *get_exit_code_message_reentrant(ExitCode exitcode);
diff --git a/src/search/utils/system_unix.cc b/src/search/utils/system_unix.cc
index 6752d08b6..9b4709c54 100644
--- a/src/search/utils/system_unix.cc
+++ b/src/search/utils/system_unix.cc
@@ -164,7 +164,7 @@ static void out_of_memory_handler() {
       memory for the stack of the signal handler and raising a signal here.
     */
     write_reentrant_str(STDOUT_FILENO, "Failed to allocate memory.\n");
-    exit_with(ExitCode::SEARCH_OUT_OF_MEMORY);
+    exit_with_reentrant(ExitCode::SEARCH_OUT_OF_MEMORY);
 }

 static void signal_handler(int signal_number) {
@@ -173,7 +173,7 @@ static void signal_handler(int signal_number) {
     write_reentrant_int(STDOUT_FILENO, signal_number);
     write_reentrant_str(STDOUT_FILENO, " -- exiting\n");
     if (signal_number == SIGXCPU) {
-        exit_after_receiving_signal(ExitCode::SEARCH_OUT_OF_TIME);
+        exit_with_reentrant(ExitCode::SEARCH_OUT_OF_TIME);
     }
     raise(signal_number);
 }
diff --git a/src/search/utils/system_windows.cc b/src/search/utils/system_windows.cc
index f8b4bc1e5..a244b6846 100644
--- a/src/search/utils/system_windows.cc
+++ b/src/search/utils/system_windows.cc
@@ -15,7 +15,7 @@ using namespace std;
 namespace utils {
 void out_of_memory_handler() {
     cout << "Failed to allocate memory." << endl;
-    exit_with(ExitCode::SEARCH_OUT_OF_MEMORY);
+    exit_with_reentrant(ExitCode::SEARCH_OUT_OF_MEMORY);
 }

 void signal_handler(int signal_number) {
diff --git a/src/translate/constraints.py b/src/translate/constraints.py
index 6190acc0f..a26a2c9e7 100644
--- a/src/translate/constraints.py
+++ b/src/translate/constraints.py
@@ -1,41 +1,30 @@
 import itertools
+from typing import Iterable, List, Tuple


-class NegativeClause:
-    # disjunction of inequalities
-    def __init__(self, parts):
+class InequalityDisjunction:
+    def __init__(self, parts: List[Tuple[str, str]]):
         self.parts = parts
         assert len(parts)

     def __str__(self):
-        disj = " or ".join(["(%s != %s)" % (v1, v2)
-                            for (v1, v2) in self.parts])
-        return "(%s)" % disj
-
-    def is_satisfiable(self):
-        for part in self.parts:
-            if part[0] != part[1]:
-                return True
-        return False
-
-    def apply_mapping(self, m):
-        new_parts = [(m.get(v1, v1), m.get(v2, v2)) for (v1, v2) in self.parts]
-        return NegativeClause(new_parts)
+        disj = " or ".join([f"({v1} != {v2})" for (v1, v2) in self.parts])
+        return f"({disj})"


-class Assignment:
-    def __init__(self, equalities):
-        self.equalities = tuple(equalities)
-        # represents a conjunction of expressions ?x = ?y or ?x = d
-        # with ?x, ?y being variables and d being a domain value
+class EqualityConjunction:
+    def __init__(self, equalities: List[Tuple[str, str]]):
+        self.equalities = equalities
+        # A conjunction of expressions x = y, where x,y are either strings
+        # that denote objects or variables, or ints that denote invariant
+        # parameters.
-        self.consistent = None
-        self.mapping = None
-        self.eq_classes = None
+        self._consistent = None
+        self._representative = None  # dictionary
+        self._eq_classes = None

     def __str__(self):
-        conj = " and ".join(["(%s = %s)" % (v1, v2)
-                             for (v1, v2) in self.equalities])
-        return "(%s)" % conj
+        conj = " and ".join([f"({v1} = {v2})" for (v1, v2) in self.equalities])
+        return f"({conj})"

     def _compute_equivalence_classes(self):
         eq_classes = {}
@@ -48,113 +37,141 @@ def _compute_equivalence_classes(self):
                 c1.update(c2)
                 for elem in c2:
                     eq_classes[elem] = c1
-        self.eq_classes = eq_classes
+        self._eq_classes = eq_classes

-    def _compute_mapping(self):
-        if not self.eq_classes:
+    def _compute_representatives(self):
+        if not self._eq_classes:
             self._compute_equivalence_classes()

-        # create mapping: each key is mapped to the smallest
-        # element in its equivalence class (with objects being
-        # smaller than variables)
-        mapping = {}
-        for eq_class in self.eq_classes.values():
-            variables = [item for item in eq_class if item.startswith("?")]
-            constants = [item for item in eq_class if not item.startswith("?")]
-            if len(constants) >= 2:
-                self.consistent = False
-                self.mapping = None
+        # Choose a representative for each equivalence class. Objects are
+        # prioritized over variables and ints, but at most one object per
+        # equivalence class is allowed (otherwise the conjunction is
+        # inconsistent).
+        representative = {}
+        for eq_class in self._eq_classes.values():
+            if next(iter(eq_class)) in representative:
+                continue  # we already processed this equivalence class
+            variables = [item for item in eq_class if isinstance(item, int) or
+                         item.startswith("?")]
+            objects = [item for item in eq_class if not isinstance(item, int)
+                       and not item.startswith("?")]
+
+            if len(objects) >= 2:
+                self._consistent = False
+                self._representative = None
                 return
-            if constants:
-                set_val = constants[0]
+            if objects:
+                set_val = objects[0]
             else:
-                set_val = min(variables)
+                set_val = variables[0]
             for entry in eq_class:
-                mapping[entry] = set_val
-        self.consistent = True
-        self.mapping = mapping
+                representative[entry] = set_val
+        self._consistent = True
+        self._representative = representative

     def is_consistent(self):
-        if self.consistent is None:
-            self._compute_mapping()
-        return self.consistent
+        if self._consistent is None:
+            self._compute_representatives()
+        return self._consistent

-    def get_mapping(self):
-        if self.consistent is None:
-            self._compute_mapping()
-        return self.mapping
+    def get_representative(self):
+        if self._consistent is None:
+            self._compute_representatives()
+        return self._representative


 class ConstraintSystem:
+    """A ConstraintSystem stores two parts, both talking about the equality
+       or inequality of strings and ints (strings representing objects or
+       variables, ints representing invariant parameters):
+       - equality_DNFs is a list containing lists of EqualityConjunctions.
+         Each EqualityConjunction represents an expression of the form
+         (x1 = y1 and ... and xn = yn). A list of EqualityConjunctions can be
+         interpreted as a disjunction of such expressions. So
+         self.equality_DNFs represents a formula of the form "⋀ ⋁ ⋀ (x = y)"
+         as a list of lists of EqualityConjunctions.
+       - ineq_disjunctions is a list of InequalityDisjunctions. Each of them
+         represents an expression of the form (u1 != v1 or ... or um != vm).
+       - not_constant is a list of strings.
+
+       We say that the system is solvable if we can pick from each list of
+       EqualityConjunctions in equality_DNFs one EqualityConjunction such
+       that the finest equivalence relation induced by all the equalities in
+       the conjunctions
+       - is consistent, i.e. no equivalence class contains more than one
+         object,
+       - contains, for every disjunction in ineq_disjunctions, at least one
+         inequality whose two terms are in different equivalence classes, and
+       - does not put any element of not_constant in the same equivalence
+         class as a constant.
+       We refer to the equivalence relation as the solution of the system."""
+
     def __init__(self):
-        self.combinatorial_assignments = []
-        self.neg_clauses = []
+        self.equality_DNFs = []
+        self.ineq_disjunctions = []
+        self.not_constant = []

     def __str__(self):
-        combinatorial_assignments = []
-        for comb_assignment in self.combinatorial_assignments:
-            disj = " or ".join([str(assig) for assig in comb_assignment])
+        equality_DNFs = []
+        for eq_DNF in self.equality_DNFs:
+            disj = " or ".join([str(eq_conjunction)
+                                for eq_conjunction in eq_DNF])
             disj = "(%s)" % disj
-            combinatorial_assignments.append(disj)
-        assigs = " and\n".join(combinatorial_assignments)
-
-        neg_clauses = [str(clause) for clause in self.neg_clauses]
-        neg_clauses = " and ".join(neg_clauses)
-        return assigs + "(" + neg_clauses + ")"
-
-    def _all_clauses_satisfiable(self, assignment):
-        mapping = assignment.get_mapping()
-        for neg_clause in self.neg_clauses:
-            clause = neg_clause.apply_mapping(mapping)
-            if not clause.is_satisfiable():
-                return False
-        return True
-
-    def _combine_assignments(self, assignments):
-        new_equalities = []
-        for a in assignments:
-            new_equalities.extend(a.equalities)
-        return Assignment(new_equalities)
-
-    def add_assignment(self, assignment):
-        self.add_assignment_disjunction([assignment])
-
-    def add_assignment_disjunction(self, assignments):
-        self.combinatorial_assignments.append(assignments)
-
-    def add_negative_clause(self, negative_clause):
-        self.neg_clauses.append(negative_clause)
-
-    def combine(self, other):
-        """Combines two constraint systems to a new system"""
-        combined = ConstraintSystem()
-        combined.combinatorial_assignments = (self.combinatorial_assignments +
-                                              other.combinatorial_assignments)
-        combined.neg_clauses = self.neg_clauses + other.neg_clauses
-        return combined
-
-    def copy(self):
-        other = ConstraintSystem()
-        other.combinatorial_assignments = list(self.combinatorial_assignments)
-        other.neg_clauses = list(self.neg_clauses)
-        return other
-
-    def dump(self):
-        print("AssignmentSystem:")
-        for comb_assignment in self.combinatorial_assignments:
-            disj = " or ".join([str(assig) for assig in comb_assignment])
-            print("  ASS: ", disj)
-        for neg_clause in self.neg_clauses:
-            print("  NEG: ", str(neg_clause))
+            equality_DNFs.append(disj)
+        eq_part = " and\n".join(equality_DNFs)
+
+        ineq_disjunctions = [str(clause) for clause in self.ineq_disjunctions]
+        ineq_part = " and ".join(ineq_disjunctions)
+        return f"{eq_part} ({ineq_part}) (not constant {self.not_constant})"
+
+    def _combine_equality_conjunctions(
+            self, eq_conjunctions: Iterable[EqualityConjunction]
+            ) -> EqualityConjunction:
+        all_eq = itertools.chain.from_iterable(c.equalities
+                                               for c in eq_conjunctions)
+        return EqualityConjunction(list(all_eq))
+
+    def add_equality_conjunction(self, eq_conjunction: EqualityConjunction):
+        self.add_equality_DNF([eq_conjunction])
+
+    def add_equality_DNF(self, equality_DNF: List[EqualityConjunction]) -> None:
+        self.equality_DNFs.append(equality_DNF)
+
+    def add_inequality_disjunction(self, ineq_disj: InequalityDisjunction):
+        self.ineq_disjunctions.append(ineq_disj)
+
+    def add_not_constant(self, not_constant: str) -> None:
+        self.not_constant.append(not_constant)
+
+    def extend(self, other: "ConstraintSystem") -> None:
+        self.equality_DNFs.extend(other.equality_DNFs)
+        self.ineq_disjunctions.extend(other.ineq_disjunctions)
+        self.not_constant.extend(other.not_constant)

     def is_solvable(self):
-        """Check whether the combinatorial assignments include at least
-           one consistent assignment under which the negative clauses
-           are satisfiable"""
-        for assignments in itertools.product(*self.combinatorial_assignments):
-            combined = self._combine_assignments(assignments)
+        # cf. the explanation at the top of the class
+        def inequality_disjunction_ok(ineq_disj, representative):
+            for inequality in ineq_disj.parts:
+                a, b = inequality
+                if representative.get(a, a) != representative.get(b, b):
+                    return True
+            return False
+
+        for eq_conjunction in itertools.product(*self.equality_DNFs):
+            combined = self._combine_equality_conjunctions(eq_conjunction)
             if not combined.is_consistent():
                 continue
-            if self._all_clauses_satisfiable(combined):
-                return True
+            # Check that under the finest equivalence relation induced by the
+            # combined equality conjunction, no element of not_constant is in
+            # the same equivalence class as a constant and that each
+            # inequality disjunction contains an inequality whose two terms
+            # are in different equivalence classes.
+            representative = combined.get_representative()
+            if any(not isinstance(representative.get(s, s), int) and
+                   representative.get(s, s)[0] != "?"
+                   for s in self.not_constant):
+                continue
+            if any(not inequality_disjunction_ok(d, representative)
+                   for d in self.ineq_disjunctions):
+                continue
+            return True
         return False
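A minimal sketch of how the renamed constraint API composes (hypothetical variables `?x`, `?y` and object `a`; assumes the module is importable as `constraints`):

```python
from constraints import (ConstraintSystem, EqualityConjunction,
                         InequalityDisjunction)

system = ConstraintSystem()
# Require ?x = ?y (a DNF consisting of a single conjunction) ...
system.add_equality_DNF([EqualityConjunction([("?x", "?y")])])
# ... require ?y != a ...
system.add_inequality_disjunction(InequalityDisjunction([("?y", "a")]))
# ... and require that ?x is not forced to equal an object.
system.add_not_constant("?x")
print(system.is_solvable())  # True: {?x, ?y} is a variable-only class

# Adding ?y = a merges the classes into {?x, ?y, a}, so ?x would be
# bound to the object a, violating the not_constant requirement.
system.add_equality_DNF([EqualityConjunction([("?y", "a")])])
print(system.is_solvable())  # False
```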
diff --git a/src/translate/fact_groups.py b/src/translate/fact_groups.py
index daec4a587..884704722 100644
--- a/src/translate/fact_groups.py
+++ b/src/translate/fact_groups.py
@@ -68,7 +68,11 @@ def _update_top(self):
                 self.groups_by_size[len(candidate)].append(candidate)
         self.max_size -= 1

-def choose_groups(groups, reachable_facts):
+def choose_groups(groups, reachable_facts, negative_in_goal):
+    if negative_in_goal:
+        # We remove atoms that occur negatively in the goal from the groups
+        # to force them to be encoded with binary variables.
+        groups = [set(group) - negative_in_goal for group in groups]
     queue = GroupCoverQueue(groups)
     uncovered_facts = reachable_facts.copy()
     result = []
@@ -107,7 +111,8 @@ def sort_groups(groups):
     return sorted(sorted(group) for group in groups)

 def compute_groups(task: pddl.Task, atoms: Set[pddl.Literal],
-                   reachable_action_params: Dict[pddl.Action, List[str]]) -> Tuple[
+                   reachable_action_params: Dict[pddl.Action, List[str]],
+                   negative_in_goal: Set[pddl.Atom]) -> Tuple[
         List[List[pddl.Atom]],  # groups
         # -> all selected mutex groups plus singleton groups for uncovered facts
         List[List[pddl.Atom]],  # mutex_groups
@@ -128,7 +133,7 @@ def compute_groups(task: pddl.Task, atoms: Set[pddl.Literal],
     with timers.timing("Collecting mutex groups"):
         mutex_groups = collect_all_mutex_groups(groups, atoms)
     with timers.timing("Choosing groups", block=True):
-        groups = choose_groups(groups, atoms)
+        groups = choose_groups(groups, atoms, negative_in_goal)
     groups = sort_groups(groups)
     with timers.timing("Building translation key"):
         translation_key = build_translation_key(groups)
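To illustrate the pruning the new `negative_in_goal` argument performs, here is a toy sketch with plain strings standing in for `pddl.Atom` objects:

```python
# Hypothetical mutex group: the atom "at(p1, loc2)" appears negatively in
# the goal, so it is removed from the group and will later be covered by a
# singleton group, i.e. encoded as its own binary variable.
groups = [{"at(p1, loc1)", "at(p1, loc2)", "at(p1, loc3)"}]
negative_in_goal = {"at(p1, loc2)"}
groups = [set(group) - negative_in_goal for group in groups]
# groups now contains only {"at(p1, loc1)", "at(p1, loc3)"}
```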
diff --git a/src/translate/invariant_finder.py b/src/translate/invariant_finder.py
index a8f32ffa5..9dded6dfa 100755
--- a/src/translate/invariant_finder.py
+++ b/src/translate/invariant_finder.py
@@ -3,6 +3,7 @@
 from collections import deque, defaultdict
 import itertools
+import random
 import time
 from typing import List

@@ -13,7 +14,8 @@ class BalanceChecker:
     def __init__(self, task, reachable_action_params):
-        self.predicates_to_add_actions = defaultdict(set)
+        self.predicates_to_add_actions = defaultdict(list)
+        self.random = random.Random(314159)
         self.action_to_heavy_action = {}
         for act in task.actions:
             action = self.add_inequality_preconds(act, reachable_action_params)
@@ -27,7 +29,9 @@ def __init__(self, task, reachable_action_params):
                     too_heavy_effects.append(eff.copy())
                 if not eff.literal.negated:
                     predicate = eff.literal.predicate
-                    self.predicates_to_add_actions[predicate].add(action)
+                    add_actions = self.predicates_to_add_actions[predicate]
+                    if not add_actions or add_actions[-1] is not action:
+                        add_actions.append(action)
             if create_heavy_act:
                 heavy_act = pddl.Action(action.name, action.parameters,
                                         action.num_external_parameters,
@@ -38,7 +42,7 @@ def __init__(self, task, reachable_action_params):
         self.action_to_heavy_action[action] = heavy_act

     def get_threats(self, predicate):
-        return self.predicates_to_add_actions.get(predicate, set())
+        return self.predicates_to_add_actions.get(predicate, list())

     def get_heavy_action(self, action):
         return self.action_to_heavy_action[action]
@@ -79,9 +83,12 @@ def get_fluents(task):
 def get_initial_invariants(task):
     for predicate in get_fluents(task):
         all_args = list(range(len(predicate.arguments)))
-        for omitted_arg in [-1] + all_args:
-            order = [i for i in all_args if i != omitted_arg]
-            part = invariants.InvariantPart(predicate.name, order, omitted_arg)
+        part = invariants.InvariantPart(predicate.name, all_args, None)
+        yield invariants.Invariant((part,))
+        for omitted in range(len(predicate.arguments)):
+            inv_args = (all_args[0:omitted] + [invariants.COUNTED] +
+                        all_args[omitted:-1])
+            part = invariants.InvariantPart(predicate.name, inv_args, omitted)
             yield invariants.Invariant((part,))

 def find_invariants(task, reachable_action_params):
@@ -112,25 +119,32 @@ def useful_groups(invariants, initial_facts):
         for predicate in invariant.predicates:
             predicate_to_invariants[predicate].append(invariant)

-    nonempty_groups = set()
+    nonempty_groups = dict()  # dict instead of set because its order is stable
     overcrowded_groups = set()
     for atom in initial_facts:
         if isinstance(atom, pddl.Assign):
             continue
         for invariant in predicate_to_invariants.get(atom.predicate, ()):
-            group_key = (invariant, tuple(invariant.get_parameters(atom)))
+            parameters = invariant.get_parameters(atom)
+            # We need to make the parameters dictionary hashable, so we
+            # store the values as a tuple.
+            parameters_tuple = tuple(parameters[var]
+                                     for var in range(invariant.arity()))
+
+            group_key = (invariant, parameters_tuple)
             if group_key not in nonempty_groups:
-                nonempty_groups.add(group_key)
+                nonempty_groups[group_key] = True
             else:
                 overcrowded_groups.add(group_key)
-    useful_groups = nonempty_groups - overcrowded_groups
+    useful_groups = [group_key for group_key in nonempty_groups.keys()
+                     if group_key not in overcrowded_groups]
     for (invariant, parameters) in useful_groups:
         yield [part.instantiate(parameters) for part in sorted(invariant.parts)]

 # returns a list of mutex groups (parameters instantiated, counted variables not)
 def get_groups(task, reachable_action_params=None) -> List[List[pddl.Atom]]:
     with timers.timing("Finding invariants", block=True):
-        invariants = sorted(find_invariants(task, reachable_action_params))
+        invariants = list(find_invariants(task, reachable_action_params))
     with timers.timing("Checking invariant weight"):
         result = list(useful_groups(invariants, task.init))
     return result
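For reference, a sketch of the argument encodings the rewritten `get_initial_invariants` loop produces for a fluent with two arguments (`COUNTED` is `-1`, as defined in `invariants.py` below):

```python
# For a binary fluent P, the loop yields three one-part invariants with
# these InvariantPart argument encodings:
#   args            omitted_pos
#   [0, 1]          None    # both positions are invariant parameters
#   [COUNTED, 0]    0       # first position is counted
#   [0, COUNTED]    1       # second position is counted
```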
diff --git a/src/translate/invariants.py b/src/translate/invariants.py
index ed98e4edc..a4d593a89 100644
--- a/src/translate/invariants.py
+++ b/src/translate/invariants.py
@@ -8,38 +8,35 @@
 # Notes:
 # All parts of an invariant always use all non-counted variables
 # -> the arity of all predicates covered by an invariant is either the
-#    number of the invariant variables or this value + 1
+#    number of the invariant parameters or this value + 1
 #
-# we currently keep the assumption that each predicate occurs at most once
-# in every invariant.
-
-def invert_list(alist):
-    result = defaultdict(list)
-    for pos, arg in enumerate(alist):
-        result[arg].append(pos)
-    return result
+# We only consider invariants where each predicate occurs in at most one part.

+COUNTED = -1

 def instantiate_factored_mapping(pairs):
+    """Input pairs is a list [(preimg1, img1), ..., (preimgn, imgn)].
+       For entry (preimg, img), preimg is a list of numbers and img a list of
+       invariant parameters or COUNTED of the same length. All preimages (and
+       all images) are pairwise disjoint, as are the components of each
+       preimage/image.
+
+       The function determines all possible bijections between the union of
+       preimgs and the union of imgs such that for every entry (preimg, img),
+       all values from preimg are mapped to values from img.
+       It yields one permutation after the other, each represented as a list
+       of pairs (x, y), meaning that x is mapped to y.
+       """
+    # For every entry (preimg, img) in pairs, determine all possible
+    # bijections from preimg to img.
     part_mappings = [[list(zip(preimg, perm_img))
                       for perm_img in itertools.permutations(img)]
                      for (preimg, img) in pairs]
-    return tools.cartesian_product(part_mappings)
-
-
-def find_unique_variables(action, invariant):
-    # find unique names for invariant variables
-    params = {p.name for p in action.parameters}
-    for eff in action.effects:
-        params.update([p.name for p in eff.parameters])
-    inv_vars = []
-    counter = itertools.count()
-    for _ in range(invariant.arity()):
-        while True:
-            new_name = "?v%i" % next(counter)
-            if new_name not in params:
-                inv_vars.append(new_name)
-                break
-    return inv_vars
+    # All possibilities to pick one bijection for each entry:
+    if not part_mappings:
+        yield []
+    else:
+        for x in itertools.product(*part_mappings):
+            yield list(itertools.chain.from_iterable(x))


 def get_literals(condition):
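A small worked example of the rewritten generator (hypothetical inputs; `COUNTED` is `-1`):

```python
# Positions [0, 2] of the other literal must be mapped bijectively to the
# invariant parameters [1, 0]; position [1] must be mapped to COUNTED.
pairs = [([0, 2], [1, 0]), ([1], [COUNTED])]
for mapping in instantiate_factored_mapping(pairs):
    print(mapping)
# [(0, 1), (2, 0), (1, -1)]
# [(0, 0), (2, 1), (1, -1)]
```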
@@ -54,17 +51,27 @@ def ensure_conjunction_sat(system, *parts):
        conjunction of all parts is satisfiable.

        Each part must be an iterator, generator, or an iterable over
-       literals."""
+       literals.
+
+       We add the following constraints for each literal to the system:
+
+       - for (not (= x y)): x != y (as an InequalityDisjunction with one
+         entry (x, y)),
+       - for (= x y): x = y,
+       - for predicates that occur with a positive and negative literal, we
+         consider every combination of a positive one (e.g. P(x, y, z)) and
+         a negative one (e.g. (not P(a, b, c))) and add a constraint
+         (x != a or y != b or z != c)."""
     pos = defaultdict(set)
     neg = defaultdict(set)
     for literal in itertools.chain(*parts):
         if literal.predicate == "=":  # use (in)equalities in conditions
             if literal.negated:
-                n = constraints.NegativeClause([literal.args])
-                system.add_negative_clause(n)
+                d = constraints.InequalityDisjunction([literal.args])
+                system.add_inequality_disjunction(d)
             else:
-                a = constraints.Assignment([literal.args])
-                system.add_assignment_disjunction([a])
+                a = constraints.EqualityConjunction([literal.args])
+                system.add_equality_DNF([a])
         else:
             if literal.negated:
                 neg[literal.predicate].add(literal)
@@ -77,125 +84,192 @@ def ensure_conjunction_sat(system, *parts):
             for negatom in neg[pred]:
                 parts = list(zip(negatom.args, posatom.args))
                 if parts:
-                    negative_clause = constraints.NegativeClause(parts)
-                    system.add_negative_clause(negative_clause)
+                    ineq_disj = constraints.InequalityDisjunction(parts)
+                    system.add_inequality_disjunction(ineq_disj)


-def ensure_cover(system, literal, invariant, inv_vars):
-    """Modifies the constraint system such that it is only solvable if the
-       invariant covers the literal"""
-    a = invariant.get_covering_assignments(inv_vars, literal)
-    assert len(a) == 1
-    # if invariants could contain several parts of one predicate, this would
-    # not be true but the depending code in parts relies on this assumption
-    system.add_assignment_disjunction(a)
+def ensure_cover(system, literal, invariant):
+    """Modifies the constraint system such that in every solution the
+       invariant covers the literal (= the invariant parameters are
+       equivalent to the corresponding arguments in the literal)."""
+    cover = invariant._get_cover_equivalence_conjunction(literal)
+    system.add_equality_DNF([cover])


 def ensure_inequality(system, literal1, literal2):
     """Modifies the constraint system such that it is only solvable if the
        literal instantiations are not equal (ignoring whether one is negated
-       and the other is not)"""
-    if (literal1.predicate == literal2.predicate and
-        literal1.args):
+       and the other is not).
+
+       If the literals have different predicates, there is nothing to do.
+       Otherwise we add for P(x, y, z) and P(a, b, c) a constraint
+       (x != a or y != b or z != c)."""
+    if (literal1.predicate == literal2.predicate and literal1.args):
         parts = list(zip(literal1.args, literal2.args))
-        system.add_negative_clause(constraints.NegativeClause(parts))
+        system.add_inequality_disjunction(constraints.InequalityDisjunction(parts))


 class InvariantPart:
-    def __init__(self, predicate, order, omitted_pos=-1):
+    def __init__(self, predicate, args, omitted_pos=None):
+        """There is one InvariantPart for every predicate mentioned in the
+           invariant. The arguments args contain numbers 0, 1, ... for the
+           invariant parameters and COUNTED at the omitted position.
+           If no position is omitted, omitted_pos is None, otherwise it is
+           the index of COUNTED in args."""
         self.predicate = predicate
-        self.order = order
+        self.args = tuple(args)
         self.omitted_pos = omitted_pos

     def __eq__(self, other):
         # This implies equality of the omitted_pos component.
-        return self.predicate == other.predicate and self.order == other.order
+        return self.predicate == other.predicate and self.args == other.args

     def __ne__(self, other):
-        return self.predicate != other.predicate or self.order != other.order
+        return self.predicate != other.predicate or self.args != other.args

     def __le__(self, other):
-        return self.predicate <= other.predicate or self.order <= other.order
+        return (self.predicate, self.args) <= (other.predicate, other.args)

     def __lt__(self, other):
-        return self.predicate < other.predicate or self.order < other.order
+        return (self.predicate, self.args) < (other.predicate, other.args)

     def __hash__(self):
-        return hash((self.predicate, tuple(self.order)))
+        return hash((self.predicate, self.args))

     def __str__(self):
-        var_string = " ".join(map(str, self.order))
-        omitted_string = ""
-        if self.omitted_pos != -1:
-            omitted_string = " [%d]" % self.omitted_pos
-        return "%s %s%s" % (self.predicate, var_string, omitted_string)
+        return f"{self.predicate}({self.args}) [omitted_pos = {self.omitted_pos}]"

     def arity(self):
-        return len(self.order)
-
-    def get_assignment(self, parameters, literal):
-        equalities = [(arg, literal.args[argpos])
-                      for arg, argpos in zip(parameters, self.order)]
-        return constraints.Assignment(equalities)
+        if self.omitted_pos is None:
+            return len(self.args)
+        else:
+            return len(self.args) - 1

     def get_parameters(self, literal):
-        return [literal.args[pos] for pos in self.order]
-
-    def instantiate(self, parameters):
-        args = ["?X"] * (len(self.order) + (self.omitted_pos != -1))
-        for arg, argpos in zip(parameters, self.order):
-            args[argpos] = arg
+        """Returns a dictionary, mapping the invariant parameters to the
+           corresponding values in the literal."""
+        return dict((arg, literal.args[pos])
+                    for pos, arg in enumerate(self.args)
+                    if pos != self.omitted_pos)
+
+    def instantiate(self, parameters_tuple):
+        args = [parameters_tuple[arg] if arg != COUNTED else "?X"
+                for arg in self.args]
         return pddl.Atom(self.predicate, args)
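A quick sketch of the new args-based representation in action (hypothetical atom; assumes the surrounding module context, i.e. `InvariantPart`, `COUNTED` and `pddl`):

```python
# P(?@v1, _, ?@v0): invariant parameters 1 and 0, second position counted.
part = InvariantPart("P", [1, COUNTED, 0], omitted_pos=1)
atom = pddl.Atom("P", ["?a", "?b", "?c"])
print(part.arity())                    # 2
print(part.get_parameters(atom))       # {1: '?a', 0: '?c'}
print(part.instantiate(("?c", "?a")))  # roughly: Atom P(?a, ?X, ?c)
```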
     def possible_mappings(self, own_literal, other_literal):
-        allowed_omissions = len(other_literal.args) - len(self.order)
+        """This method is used when an action has an unbalanced add effect
+           own_literal. The action has a delete effect on literal
+           other_literal, so we try to refine the invariant such that it also
+           covers the delete effect.
+
+           From own_literal, we can determine a variable or object for every
+           invariant parameter, where potentially several invariant
+           parameters can have the same value.
+
+           From the arguments of other_literal, we determine all
+           possibilities how we can use the invariant parameters as arguments
+           of other_literal so that the values match (possibly covering one
+           parameter with a placeholder/counted variable). Since there can
+           also be duplicates in the arguments of other_literal, we cannot
+           operate on the arguments directly, but instead operate on the
+           positions.
+
+           The method returns [] if there is no possible mapping and
+           otherwise yields the mappings from the positions of other_literal
+           to the invariant variables or COUNTED one by one.
+           """
+        allowed_omissions = len(other_literal.args) - self.arity()
+        # All parts of an invariant always use all non-counted variables, of
+        # which we have arity() many. So we must omit allowed_omissions many
+        # arguments of other_literal when matching invariant parameters with
+        # arguments.
         if allowed_omissions not in (0, 1):
+            # There may be at most one counted variable.
             return []
         own_parameters = self.get_parameters(own_literal)
-        arg_to_ordered_pos = invert_list(own_parameters)
-        other_arg_to_pos = invert_list(other_literal.args)
+        # own_parameters is a dictionary mapping the invariant parameters to
+        # the corresponding parameters of own_literal.
+        ownarg_to_invariant_parameters = defaultdict(list)
+        for x, y in own_parameters.items():
+            ownarg_to_invariant_parameters[y].append(x)
+
+        # other_arg_to_pos maps every argument of other_literal to the list
+        # of positions at which it occurs in other_literal, e.g. for
+        # P(?a, ?b, ?a), other_arg_to_pos["?a"] = [0, 2].
+        other_arg_to_pos = defaultdict(list)
+        for pos, arg in enumerate(other_literal.args):
+            other_arg_to_pos[arg].append(pos)
+
         factored_mapping = []
+        # We iterate over all values occurring as arguments in other_literal
+        # and compare the number of occurrences in other_literal to those in
+        # own_literal. If the difference of these numbers allows us to cover
+        # other_literal with the (still) permitted number of counted
+        # variables, we store the correspondence of all argument positions of
+        # other_literal for this value to the invariant parameters at these
+        # positions in factored_mapping. If all values can be covered, we
+        # instantiate the complete factored_mapping, computing all
+        # possibilities to map positions from other_literal to invariant
+        # parameters (or COUNTED if the position is omitted).
         for key, other_positions in other_arg_to_pos.items():
-            own_positions = arg_to_ordered_pos.get(key, [])
-            len_diff = len(own_positions) - len(other_positions)
+            inv_params = ownarg_to_invariant_parameters[key]
+            # all invariant parameters at whose positions key occurs as an
+            # argument in own_literal
+            len_diff = len(inv_params) - len(other_positions)
             if len_diff >= 1 or len_diff <= -2 or len_diff == -1 and not allowed_omissions:
+                # A mapping of the literals is not possible with at most one
+                # counted variable.
                 return []
             if len_diff:
-                own_positions.append(-1)
+                inv_params.append(COUNTED)
                 allowed_omissions = 0
-            factored_mapping.append((other_positions, own_positions))
+            factored_mapping.append((other_positions, inv_params))
         return instantiate_factored_mapping(factored_mapping)
     def possible_matches(self, own_literal, other_literal):
+        """This method is used when an action has an unbalanced add effect
+           on own_literal. The action has a delete effect on literal
+           other_literal, so we try to refine the invariant such that it also
+           covers the delete effect.
+
+           For this purpose, we consider all possible mappings from the
+           parameter positions of other_literal to the parameter positions of
+           own_literal such that the extended invariant can use other_literal
+           to balance own_literal. From these position mappings, we can
+           extract the new invariant part.
+
+           Consider as an example the "self" InvariantPart "forall ?@v0,
+           ?@v1, ?@v2: P(?@v0, ?@v1, ?@v2) is non-increasing" and let
+           own_literal be P(?a, ?b, ?c) and other_literal be Q(?b, ?c, ?d, ?a).
+           The only possible mapping from positions of Q to invariant
+           variables of P (or COUNTED) is [0->?@v1, 1->?@v2, 2->COUNTED,
+           3->?@v0], for which we create a new InvariantPart
+           Q(?@v1, ?@v2, _, ?@v0) with the third argument being counted.
+           """
         assert self.predicate == own_literal.predicate
-        result = []
         for mapping in self.possible_mappings(own_literal, other_literal):
-            new_order = [None] * len(self.order)
-            omitted = -1
-            for (key, value) in mapping:
-                if value == -1:
-                    omitted = key
+            args = [COUNTED] * len(other_literal.args)
+            omitted = None
+            for (other_pos, inv_var) in mapping:
+                if inv_var == COUNTED:
+                    omitted = other_pos
                 else:
-                    new_order[value] = key
-            result.append(InvariantPart(other_literal.predicate, new_order, omitted))
-        return result
-
-    def matches(self, other, own_literal, other_literal):
-        return self.get_parameters(own_literal) == other.get_parameters(other_literal)
+                    args[other_pos] = inv_var
+            yield InvariantPart(other_literal.predicate, args, omitted)


 class Invariant:
     # An invariant is a logical expression of the type
-    #   forall V1...Vk: sum_(part in parts) weight(part, V1, ..., Vk) <= 1.
+    #   forall ?@v1...?@vk: sum_(part in parts) weight(part, ?@v1, ..., ?@vk) <= 1.
     # k is called the arity of the invariant.
-    # A "part" is a symbolic fact only variable symbols in {V1, ..., Vk, X};
-    # the symbol X may occur at most once.
+    # A "part" is an atom that only contains arguments from {?@v1, ..., ?@vk,
+    # COUNTED} but instead of ?@vi, we store it as int i; COUNTED may occur at
+    # most once.

     def __init__(self, parts):
         self.parts = frozenset(parts)
-        self.predicates = {part.predicate for part in parts}
         self.predicate_to_part = {part.predicate: part for part in parts}
+        self.predicates = set(self.predicate_to_part.keys())
         assert len(self.parts) == len(self.predicates)

     def __eq__(self, other):
@@ -204,17 +278,11 @@ def __eq__(self, other):
     def __ne__(self, other):
         return self.parts != other.parts

-    def __lt__(self, other):
-        return self.parts < other.parts
-
-    def __le__(self, other):
-        return self.parts <= other.parts
-
     def __hash__(self):
         return hash(self.parts)

     def __str__(self):
-        return "{%s}" % ", ".join(str(part) for part in self.parts)
+        return "{%s}" % ", ".join(sorted(str(part) for part in self.parts))

     def __repr__(self):
         return '<Invariant %s>' % self
@@ -228,30 +296,57 @@ def get_parameters(self, atom):
     def instantiate(self, parameters):
         return [part.instantiate(parameters) for part in self.parts]

-    def get_covering_assignments(self, parameters, atom):
-        part = self.predicate_to_part[atom.predicate]
-        return [part.get_assignment(parameters, atom)]
-        # if there were more parts for the same predicate the list
-        # contained more than one element
+    def _get_cover_equivalence_conjunction(self, literal):
+        """This is only called for atoms with a predicate for which the
+           invariant has a part. It returns an equivalence conjunction that
+           requires every invariant parameter to be equal to the
+           corresponding argument of the given literal. For the result, we
+           do not consider whether the literal is negated.
+
+           Example: If the literal is P(?a, ?b, ?c) and the invariant part
+           for P is P(?@v0, _, ?@v1), then the method returns the constraint
+           (?@v0 = ?a and ?@v1 = ?c).
+           """
+        part = self.predicate_to_part[literal.predicate]
+        equalities = [(arg, literal.args[pos])
+                      for pos, arg in enumerate(part.args)
+                      if arg != COUNTED]
+        return constraints.EqualityConjunction(equalities)
+        # If there were more parts for the same predicate, we would have to
+        # consider more than one assignment (disjunctively).
+        # We assert earlier that this is not the case.
     def check_balance(self, balance_checker, enqueue_func):
         # Check balance for this hypothesis.
-        actions_to_check = set()
-        for part in self.parts:
-            actions_to_check |= balance_checker.get_threats(part.predicate)
-        for action in actions_to_check:
+        actions_to_check = dict()
+        # We will only use the keys of the dictionary. We do not use a set
+        # because it's not stable and introduces non-determinism in the
+        # invariance analysis.
+        for part in sorted(self.parts):
+            for a in balance_checker.get_threats(part.predicate):
+                actions_to_check[a] = True
+
+        actions = list(actions_to_check.keys())
+        while actions:
+            # For a better expected performance, we want to randomize the
+            # order in which actions are checked. Since candidates are often
+            # already discarded by an early check, we do not want to pay for
+            # a full shuffle but instead always draw the next action randomly
+            # from those we did not yet consider.
+            pos = balance_checker.random.randrange(len(actions))
+            actions[pos], actions[-1] = actions[-1], actions[pos]
+            action = actions.pop()
             heavy_action = balance_checker.get_heavy_action(action)
-            if self.operator_too_heavy(heavy_action):
+            if self._operator_too_heavy(heavy_action):
                 return False
-            if self.operator_unbalanced(action, enqueue_func):
+            if self._operator_unbalanced(action, enqueue_func):
                 return False
         return True
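The draw-without-shuffle pattern used above, isolated as a runnable sketch (fixed seed mirroring `balance_checker.random`):

```python
import random

rng = random.Random(314159)  # same fixed seed as in the BalanceChecker
actions = ["a1", "a2", "a3", "a4"]
while actions:
    # Swap a uniformly drawn element to the back, then pop it: O(1) per
    # draw, and an early exit never pays for shuffling the whole list.
    pos = rng.randrange(len(actions))
    actions[pos], actions[-1] = actions[-1], actions[pos]
    action = actions.pop()
    print(action)
```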
-    def operator_too_heavy(self, h_action):
+    def _operator_too_heavy(self, h_action):
         add_effects = [eff for eff in h_action.effects
                        if not eff.literal.negated and
                        self.predicate_to_part.get(eff.literal.predicate)]
-        inv_vars = find_unique_variables(h_action, self)

         if len(add_effects) <= 1:
             return False
@@ -259,8 +354,8 @@ def _operator_too_heavy(self, h_action):
         for eff1, eff2 in itertools.combinations(add_effects, 2):
             system = constraints.ConstraintSystem()
             ensure_inequality(system, eff1.literal, eff2.literal)
-            ensure_cover(system, eff1.literal, self, inv_vars)
-            ensure_cover(system, eff2.literal, self, inv_vars)
+            ensure_cover(system, eff1.literal, self)
+            ensure_cover(system, eff2.literal, self)
             ensure_conjunction_sat(system, get_literals(h_action.precondition),
                                    get_literals(eff1.condition),
                                    get_literals(eff2.condition),
@@ -270,8 +365,7 @@ def _operator_too_heavy(self, h_action):
                 return True
         return False

-    def operator_unbalanced(self, action, enqueue_func):
-        inv_vars = find_unique_variables(action, self)
+    def _operator_unbalanced(self, action, enqueue_func):
         relevant_effs = [eff for eff in action.effects
                          if self.predicate_to_part.get(eff.literal.predicate)]
         add_effects = [eff for eff in relevant_effs
@@ -279,59 +373,80 @@ def _operator_unbalanced(self, action, enqueue_func):
         del_effects = [eff for eff in relevant_effs
                        if eff.literal.negated]
         for eff in add_effects:
-            if self.add_effect_unbalanced(action, eff, del_effects, inv_vars,
-                                          enqueue_func):
+            if self._add_effect_unbalanced(action, eff, del_effects,
+                                           enqueue_func):
                 return True
         return False

-    def minimal_covering_renamings(self, action, add_effect, inv_vars):
-        """computes the minimal renamings of the action parameters such
-           that the add effect is covered by the action.
-           Each renaming is an constraint system"""
-
-        # add_effect must be covered
-        assigs = self.get_covering_assignments(inv_vars, add_effect.literal)
-
-        # renaming of operator parameters must be minimal
-        minimal_renamings = []
-        params = [p.name for p in action.parameters]
-        for assignment in assigs:
-            system = constraints.ConstraintSystem()
-            system.add_assignment(assignment)
-            mapping = assignment.get_mapping()
-            if len(params) > 1:
-                for (n1, n2) in itertools.combinations(params, 2):
-                    if mapping.get(n1, n1) != mapping.get(n2, n2):
-                        negative_clause = constraints.NegativeClause([(n1, n2)])
-                        system.add_negative_clause(negative_clause)
-            minimal_renamings.append(system)
-        return minimal_renamings
-
-    def add_effect_unbalanced(self, action, add_effect, del_effects,
-                              inv_vars, enqueue_func):
-
-        minimal_renamings = self.minimal_covering_renamings(action, add_effect,
-                                                            inv_vars)
-
-        lhs_by_pred = defaultdict(list)
+    def _add_effect_unbalanced(self, action, add_effect, del_effects,
+                               enqueue_func):
+        # We build for every delete effect that is possibly covered by this
+        # invariant a constraint system that will be solvable if the delete
+        # effect balances the add effect. Large parts of the constraint
+        # system are independent of the delete effect, so we precompute them
+        # first.
+
+        # Dictionary add_effect_produced_by_pred describes what must be true
+        # so that the action is applicable and produces the add effect. It is
+        # stored as a map from predicate names to literals (overall
+        # representing a conjunction of these literals).
+        add_effect_produced_by_pred = defaultdict(list)
         for lit in itertools.chain(get_literals(action.precondition),
                                    get_literals(add_effect.condition),
                                    get_literals(add_effect.literal.negate())):
-            lhs_by_pred[lit.predicate].append(lit)
+            add_effect_produced_by_pred[lit.predicate].append(lit)
+
+        # add_cover is an equality conjunction that sets each invariant
+        # parameter equal to its value in add_effect.literal.
+        add_cover = self._get_cover_equivalence_conjunction(add_effect.literal)
+
+        # add_cover can imply equivalences between variables (and with
+        # constants). For example, if the invariant part is P(_ ?@v0 ?@v1 ?@v2)
+        # and the add effect is P(?x ?y ?y a), then we know that the
+        # invariant part is only threatened by the add effect if the first
+        # two invariant parameters are equal and the third parameter is a.
+
+        # The add effect must be balanced in all threatening action
+        # applications. We thus must adapt the constraint system such that it
+        # prevents solutions that set action parameters or quantified
+        # variables of the add effect equal to each other or to a specific
+        # constant if this is not already implied by the threat.
+        params = [p.name for p in itertools.chain(action.parameters,
+                                                  add_effect.parameters)]
+        param_system = constraints.ConstraintSystem()
+        representative = add_cover.get_representative()
+        # Dictionary representative maps every term to its representative in
+        # the finest equivalence relation induced by the equalities in
+        # add_cover. If the equivalence class contains an object, the
+        # representative is an object.
+        for param in params:
+            r = representative.get(param, param)
+            if isinstance(r, int) or r[0] == "?":
+                # For the add effect to be a threat to the invariant, param
+                # does not need to be a specific constant. So we may not bind
+                # it to a constant when balancing the add effect. We store
+                # this information here.
+                param_system.add_not_constant(param)
+        for (n1, n2) in itertools.combinations(params, 2):
+            if representative.get(n1, n1) != representative.get(n2, n2):
+                # n1 and n2 don't have to be equivalent to cover the add
+                # effect, so we require for the solutions that they do not
+                # make n1 and n2 equivalent.
+                ineq_disj = constraints.InequalityDisjunction([(n1, n2)])
+                param_system.add_inequality_disjunction(ineq_disj)

         for del_effect in del_effects:
-            minimal_renamings = self.unbalanced_renamings(
-                del_effect, add_effect, inv_vars, lhs_by_pred, minimal_renamings)
-            if not minimal_renamings:
+            if self._balances(del_effect, add_effect,
+                              add_effect_produced_by_pred, add_cover,
+                              param_system):
                 return False

-        # Otherwise, the balance check fails => Generate new candidates.
-        self.refine_candidate(add_effect, action, enqueue_func)
+        # The balance check failed => Generate new candidates.
+        self._refine_candidate(add_effect, action, enqueue_func)
         return True

-    def refine_candidate(self, add_effect, action, enqueue_func):
-        """refines the candidate for an add effect that is unbalanced in the
-           action and adds the refined one to the queue"""
+    def _refine_candidate(self, add_effect, action, enqueue_func):
+        """Refines the candidate for an add effect that is unbalanced in the
+           action and adds the refined one to the queue."""
         part = self.predicate_to_part[add_effect.literal.predicate]
         for del_eff in [eff for eff in action.effects if eff.literal.negated]:
             if del_eff.literal.predicate not in self.predicate_to_part:
@@ -339,75 +454,79 @@ def _refine_candidate(self, add_effect, action, enqueue_func):
                                                del_eff.literal):
                 enqueue_func(Invariant(self.parts.union((match,))))

-    def unbalanced_renamings(self, del_effect, add_effect, inv_vars,
-                             lhs_by_pred, unbalanced_renamings):
-        """returns the renamings from unbalanced renamings for which
-           the del_effect does not balance the add_effect."""
+    def _balances(self, del_effect, add_effect, produced_by_pred,
+                  add_cover, param_system):
+        """Returns whether the del_effect is guaranteed to balance the add
+           effect, where the input is such that:
+           - produced_by_pred must be true for the add_effect to be produced,
+           - add_cover is an equality conjunction that sets each invariant
+             parameter equal to its value in add_effect. These equivalences
+             must be true for the add effect to threaten the invariant, and
+           - param_system contains constraints that action and add_effect
+             parameters are not fixed to be equivalent or a certain constant
+             (unless the add effect is otherwise no threat)."""
+
+        balance_system = self._balance_system(add_effect, del_effect,
+                                              produced_by_pred)
+        if not balance_system:
+            # It is impossible to guarantee that every production by
+            # add_effect implies a consumption by del_effect.
+            return False

+        # We will overall build a system that is solvable if the delete
+        # effect is guaranteed to balance the add effect for this invariant.
         system = constraints.ConstraintSystem()
-        ensure_cover(system, del_effect.literal, self, inv_vars)
-
-        # Since we may only rename the quantified variables of the delete effect
-        # we need to check that "renamings" of constants are already implied by
-        # the unbalanced_renaming (of the of the operator parameters). The
-        # following system is used as a helper for this. It builds a conjunction
-        # that formulates that the constants are NOT renamed accordingly. We
We - # below check that this is impossible with each unbalanced renaming. - check_constants = False - constant_test_system = constraints.ConstraintSystem() - for a, b in system.combinatorial_assignments[0][0].equalities: - # first 0 because the system was empty before we called ensure_cover - # second 0 because ensure_cover only adds assignments with one entry - if b[0] != "?": - check_constants = True - neg_clause = constraints.NegativeClause([(a, b)]) - constant_test_system.add_negative_clause(neg_clause) + system.add_equality_conjunction(add_cover) + # In every solution, the invariant parameters must equal the + # corresponding arguments of the add effect atom. - ensure_inequality(system, add_effect.literal, del_effect.literal) + ensure_cover(system, del_effect.literal, self) + # In every solution, the invariant parameters must equal the + # corresponding arguments of the delete effect atom. + + system.extend(balance_system) + # In every solution a production by the add effect guarantees + # a consumption by the delete effect. + + system.extend(param_system) + # A solution may not restrict action parameters (must be balanced + # independent of the concrete action instantiation). - still_unbalanced = [] - for renaming in unbalanced_renamings: - if check_constants: - new_sys = constant_test_system.combine(renaming) - if new_sys.is_solvable(): - # it is possible that the operator arguments are not - # mapped to constants as required for covering the delete - # effect - still_unbalanced.append(renaming) - continue - - new_sys = system.combine(renaming) - if self.lhs_satisfiable(renaming, lhs_by_pred): - implies_system = self.imply_del_effect(del_effect, lhs_by_pred) - if not implies_system: - still_unbalanced.append(renaming) - continue - new_sys = new_sys.combine(implies_system) - if not new_sys.is_solvable(): - still_unbalanced.append(renaming) - return still_unbalanced - - def lhs_satisfiable(self, renaming, lhs_by_pred): - system = renaming.copy() - ensure_conjunction_sat(system, *itertools.chain(lhs_by_pred.values())) - return system.is_solvable() - - def imply_del_effect(self, del_effect, lhs_by_pred): - """returns a constraint system that is solvable if lhs implies - the del effect (only if lhs is satisfiable). If a solvable - lhs never implies the del effect, return None.""" - # del_effect.cond and del_effect.atom must be implied by lhs - implies_system = constraints.ConstraintSystem() + if not system.is_solvable(): + return False + return True + + def _balance_system(self, add_effect, del_effect, literals_by_pred): + """Returns a constraint system that is solvable if + - the conjunction of literals occurring as values in dictionary + literals_by_pred (characterizing a threat for the invariant + through an actual production by add_effect) implies the + consumption of the atom of the delete effect, and + - the produced and consumed atom are different (otherwise by + add-after-delete semantics, the delete effect would not balance + the add effect). + + We return None if we detect that the constraint system would never + be solvable (by an incomplete cheap test). 
+ """ + system = constraints.ConstraintSystem() for literal in itertools.chain(get_literals(del_effect.condition), [del_effect.literal.negate()]): - poss_assignments = [] - for match in lhs_by_pred[literal.predicate]: - if match.negated != literal.negated: - continue - else: - a = constraints.Assignment(list(zip(literal.args, match.args))) - poss_assignments.append(a) - if not poss_assignments: + possibilities = [] + # possible equality conjunctions that establish that the literals + # in literals_by_pred logically imply the current literal. + for match in literals_by_pred[literal.predicate]: + if match.negated == literal.negated: + # match implies literal iff they agree on each argument + ec = constraints.EqualityConjunction(list(zip(literal.args, + match.args))) + possibilities.append(ec) + if not possibilities: return None - implies_system.add_assignment_disjunction(poss_assignments) - return implies_system + system.add_equality_DNF(possibilities) + + # if the add effect and the delete effect affect the same predicate + # then their arguments must differ in at least one position (because of + # the add-after-delete semantics). + ensure_inequality(system, add_effect.literal, del_effect.literal) + return system diff --git a/src/translate/normalize.py b/src/translate/normalize.py index 375dc67e9..4b5df01dd 100755 --- a/src/translate/normalize.py +++ b/src/translate/normalize.py @@ -1,6 +1,7 @@ #! /usr/bin/env python3 import copy +from typing import Sequence import pddl @@ -23,7 +24,19 @@ def delete_owner(self, task): def build_rules(self, rules): action = self.owner rule_head = get_action_predicate(action) - rule_body = condition_to_rule_body(action.parameters, self.condition) + + # If the action cost is based on a primitive numeric expression, + # we need to require that it has a value defined in the initial state. + # We hand it over to condition_to_rule_body to include this in the rule + # body. + pne = None + if (isinstance(action.cost, pddl.Increase) and + isinstance(action.cost.expression, + pddl.PrimitiveNumericExpression)): + pne = action.cost.expression + + rule_body = condition_to_rule_body(action.parameters, self.condition, + pne) rules.append((rule_body, rule_head)) def get_type_map(self): return self.owner.type_map @@ -117,6 +130,9 @@ def get_axiom_predicate(axiom): variables += [par.name for par in axiom.condition.parameters] return pddl.Atom(name, variables) +def get_pne_definition_predicate(pne: pddl.PrimitiveNumericExpression): + return pddl.Atom(f"@def-{pne.symbol}", pne.args) + def all_conditions(task): for action in task.actions: yield PreconditionProxy(action) @@ -366,10 +382,24 @@ def build_exploration_rules(task): proxy.build_rules(result) return result -def condition_to_rule_body(parameters, condition): +def condition_to_rule_body(parameters: Sequence[pddl.TypedObject], + condition: pddl.conditions.Condition, + pne: pddl.PrimitiveNumericExpression = None): + """The rule body requires that + - all parameters (including existentially quantified variables in the + condition) are instantiated with objecst of the right type, + - all positive atoms in the condition (which must be normalized) are + true in the Prolog model, and + - the primitive numeric expression (from the action cost) has a defined + value (in the initial state).""" result = [] + # Require parameters to be instantiated with objects of the right type. 
for par in parameters: result.append(par.get_atom()) + + # Require each positive literal in the condition to be reached and + # existentially quantified variables of the condition to be instantiated + # with objects of the right type. if not isinstance(condition, pddl.Truth): if isinstance(condition, pddl.ExistentialCondition): for par in condition.parameters: @@ -388,6 +418,11 @@ def condition_to_rule_body(parameters, condition): assert isinstance(part, pddl.Literal), "Condition not normalized: %r" % part if not part.negated: result.append(part) + + # Require the primitive numeric expression (from the action cost) to be + # defined. + if pne is not None: + result.append(get_pne_definition_predicate(pne)) return result if __name__ == "__main__": diff --git a/src/translate/pddl/tasks.py b/src/translate/pddl/tasks.py index 6663a95af..0e28115ee 100644 --- a/src/translate/pddl/tasks.py +++ b/src/translate/pddl/tasks.py @@ -68,15 +68,22 @@ def dump(self): for axiom in self.axioms: axiom.dump() + +REQUIREMENT_LABELS = [ + ":strips", ":adl", ":typing", ":negation", ":equality", + ":negative-preconditions", ":disjunctive-preconditions", + ":existential-preconditions", ":universal-preconditions", + ":quantified-preconditions", ":conditional-effects", + ":derived-predicates", ":action-costs" +] + + class Requirements: def __init__(self, requirements: List[str]): self.requirements = requirements for req in requirements: - assert req in ( - ":strips", ":adl", ":typing", ":negation", ":equality", - ":negative-preconditions", ":disjunctive-preconditions", - ":existential-preconditions", ":universal-preconditions", - ":quantified-preconditions", ":conditional-effects", - ":derived-predicates", ":action-costs"), req + if req not in REQUIREMENT_LABELS: + raise ValueError(f"Invalid requirement. Got: {req}\n" + f"Expected: {', '.join(REQUIREMENT_LABELS)}") def __str__(self): return ", ".join(self.requirements) diff --git a/src/translate/pddl_parser/__init__.py b/src/translate/pddl_parser/__init__.py index 32f518658..c6290588b 100644 --- a/src/translate/pddl_parser/__init__.py +++ b/src/translate/pddl_parser/__init__.py @@ -1 +1,2 @@ +from .parse_error import ParseError from .pddl_file import open diff --git a/src/translate/pddl_parser/lisp_parser.py b/src/translate/pddl_parser/lisp_parser.py index 271cb646b..872b29784 100644 --- a/src/translate/pddl_parser/lisp_parser.py +++ b/src/translate/pddl_parser/lisp_parser.py @@ -1,20 +1,18 @@ -__all__ = ["ParseError", "parse_nested_list"] +__all__ = ["parse_nested_list"] -class ParseError(Exception): - def __init__(self, value): - self.value = value - def __str__(self): - return self.value +from .parse_error import ParseError # Basic functions for parsing PDDL (Lisp) files. def parse_nested_list(input_file): tokens = tokenize(input_file) next_token = next(tokens) if next_token != "(": - raise ParseError("Expected '(', got %s." % next_token) + raise ParseError(f"Expected '(', got '{next_token}'.") result = list(parse_list_aux(tokens)) - for tok in tokens: # Check that generator is exhausted. - raise ParseError("Unexpected token: %s." 
% tok) + remaining_tokens = list(tokens) + if remaining_tokens: + raise ParseError(f"Tokens remaining after parsing: " + f"{' '.join(remaining_tokens)}") return result def tokenize(input): @@ -23,8 +21,7 @@ def tokenize(input): try: line.encode("ascii") except UnicodeEncodeError: - raise ParseError("Non-ASCII character outside comment: %s" % - line[0:-1]) + raise ParseError(f"Non-ASCII character outside comment: {line[0:-1]}") line = line.replace("(", " ( ").replace(")", " ) ").replace("?", " ?") for token in line.split(): yield token.lower() diff --git a/src/translate/pddl_parser/parse_error.py b/src/translate/pddl_parser/parse_error.py new file mode 100644 index 000000000..831cca9f0 --- /dev/null +++ b/src/translate/pddl_parser/parse_error.py @@ -0,0 +1,2 @@ +class ParseError(Exception): + pass diff --git a/src/translate/pddl_parser/parsing_functions.py b/src/translate/pddl_parser/parsing_functions.py index fdc0b9dc8..e21e81b9d 100644 --- a/src/translate/pddl_parser/parsing_functions.py +++ b/src/translate/pddl_parser/parsing_functions.py @@ -1,31 +1,142 @@ +import contextlib import sys import graph import pddl +from .parse_error import ParseError +TYPED_LIST_SEPARATOR = "-" -def parse_typed_list(alist, only_variables=False, - constructor=pddl.TypedObject, +SYNTAX_LITERAL = "(PREDICATE ARGUMENTS*)" +SYNTAX_LITERAL_NEGATED = "(not (PREDICATE ARGUMENTS*))" +SYNTAX_LITERAL_POSSIBLY_NEGATED = f"{SYNTAX_LITERAL} or {SYNTAX_LITERAL_NEGATED}" + +SYNTAX_PREDICATE = "(PREDICATE_NAME [VARIABLE [- TYPE]?]*)" +SYNTAX_PREDICATES = f"(:predicates {SYNTAX_PREDICATE}*)" +SYNTAX_FUNCTION = "(FUNCTION_NAME [VARIABLE [- TYPE]?]*)" +SYNTAX_ACTION = "(:action NAME [:parameters PARAMETERS]? " \ + "[:precondition PRECONDITION]? :effect EFFECT)" +SYNTAX_AXIOM = "(:derived PREDICATE CONDITION)" +SYNTAX_GOAL = "(:goal GOAL)" + +SYNTAX_CONDITION_AND = "(and CONDITION*)" +SYNTAX_CONDITION_OR = "(or CONDITION*)" +SYNTAX_CONDITION_IMPLY = "(imply CONDITION CONDITION)" +SYNTAX_CONDITION_NOT = "(not CONDITION)" +SYNTAX_CONDITION_FORALL_EXISTS = "({forall, exists} VARIABLES CONDITION)" + +SYNTAX_EFFECT_FORALL = "(forall VARIABLES EFFECT)" +SYNTAX_EFFECT_WHEN = "(when CONDITION EFFECT)" +SYNTAX_EFFECT_INCREASE = "(increase (total-cost) ASSIGNMENT)" + +SYNTAX_EXPRESSION = "POSITIVE_NUMBER or (FUNCTION VARIABLES*)" +SYNTAX_ASSIGNMENT = "({=,increase} EXPRESSION EXPRESSION)" + +SYNTAX_DOMAIN_DOMAIN_NAME = "(domain NAME)" +SYNTAX_TASK_PROBLEM_NAME = "(problem NAME)" +SYNTAX_TASK_DOMAIN_NAME = "(:domain NAME)" +SYNTAX_METRIC = "(:metric minimize (total-cost))" + + +CONDITION_TAG_TO_SYNTAX = { + "and": SYNTAX_CONDITION_AND, + "or": SYNTAX_CONDITION_OR, + "imply": SYNTAX_CONDITION_IMPLY, + "not": SYNTAX_CONDITION_NOT, + "forall": SYNTAX_CONDITION_FORALL_EXISTS, +} + + +class Context: + def __init__(self): + self._traceback = [] + + def __str__(self) -> str: + return "\n\t->".join(self._traceback) + + def error(self, message, item=None, syntax=None): + error_msg = f"{self}\n{message}" + if syntax: + error_msg += f"\nSyntax: {syntax}" + if item: + error_msg += f"\nGot: {item}" + raise ParseError(error_msg) + + def expected_word_error(self, name, *args, **kwargs): + self.error(f"{name} is expected to be a word.", *args, **kwargs) + + def expected_list_error(self, name, *args, **kwargs): + self.error(f"{name} is expected to be a block.", *args, **kwargs) + + def expected_named_block_error(self, alist, expected, *args, **kwargs): + self.error(f"Expected a non-empty block starting with any of the " + f"following words: {', 
'.join(expected)}", + item=alist, *args, **kwargs) + + @contextlib.contextmanager + def layer(self, message: str): + self._traceback.append(message) + yield + assert self._traceback.pop() == message + + +def check_named_block(alist, names): + return isinstance(alist, list) and alist and alist[0] in names + +def assert_named_block(context, alist, names): + if not check_named_block(alist, names): + context.expected_named_block_error(alist, names) + +def construct_typed_object(context, name, _type): + with context.layer("Parsing typed object"): + if not isinstance(name, str): + context.expected_word_error("Name of typed object", name) + return pddl.TypedObject(name, _type) + + +def construct_type(context, curr_type, base_type): + with context.layer("Parsing PDDL type"): + if not isinstance(curr_type, str): + context.expected_word_error("PDDL type", curr_type) + if not isinstance(base_type, str): + context.expected_word_error("Base type", base_type) + return pddl.Type(curr_type, base_type) + + +def parse_typed_list(context, alist, only_variables=False, + constructor=construct_typed_object, default_type="object"): - result = [] - while alist: - try: - separator_position = alist.index("-") - except ValueError: - items = alist - _type = default_type - alist = [] - else: - items = alist[:separator_position] - _type = alist[separator_position + 1] - alist = alist[separator_position + 2:] - for item in items: - assert not only_variables or item.startswith("?"), \ - "Expected item to be a variable: %s in (%s)" % ( - item, " ".join(items)) - entry = constructor(item, _type) - result.append(entry) - return result + with context.layer("Parsing typed list"): + result = [] + group_number = 1 + while alist: + with context.layer(f"Parsing {group_number}. group of typed list"): + try: + separator_position = alist.index(TYPED_LIST_SEPARATOR) + except ValueError: + items = alist + _type = default_type + alist = [] + else: + if separator_position == len(alist) - 1: + context.error( + f"Type missing after '{TYPED_LIST_SEPARATOR}'.", + alist) + items = alist[:separator_position] + _type = alist[separator_position + 1] + alist = alist[separator_position + 2:] + if not (isinstance(_type, str) or + (_type and _type[0] == "either" and + all(isinstance(_sub_type, str) for _sub_type in _type[1:]))): + context.error("Type value is expected to be a single word " + "or '(either WORD*)'.") + for item in items: + if only_variables and not item.startswith("?"): + context.error("Expected item to be a variable", item) + entry = constructor(context, item, _type) + result.append(entry) + group_number += 1 + return result def set_supertypes(type_list): @@ -42,49 +153,111 @@ def set_supertypes(type_list): type_name_to_type[desc_name].supertype_names.append(anc_name) -def parse_predicate(alist): - name = alist[0] - arguments = parse_typed_list(alist[1:], only_variables=True) +def parse_requirements(context, alist): + with context.layer("Parsing requirements"): + for item in alist: + if not isinstance(item, str): + context.expected_word_error("Requirement label", item) + try: + return pddl.Requirements(alist) + except ValueError as e: + context.error(f"Error in requirements.\n" + f"Reason: {e}") + + +def parse_predicate(context, alist): + with context.layer("Parsing predicate name"): + if not alist: + context.error("Predicate name missing", syntax=SYNTAX_PREDICATE) + name = alist[0] + if not isinstance(name, str): + context.expected_word_error("Predicate name", name) + with context.layer(f"Parsing arguments of predicate '{name}'"): +
arguments = parse_typed_list(context, alist[1:], only_variables=True) return pddl.Predicate(name, arguments) -def parse_function(alist, type_name): - name = alist[0] - arguments = parse_typed_list(alist[1:]) +def parse_predicates(context, alist): + with context.layer("Parsing predicates"): + the_predicates = [] + for no, entry in enumerate(alist, start=1): + with context.layer(f"Parsing {no}. predicate"): + if not isinstance(entry, list): + context.error("Invalid predicate definition.", + syntax=SYNTAX_PREDICATE) + the_predicates.append(parse_predicate(context, entry)) + return the_predicates + + +def parse_function(context, alist, type_name): + with context.layer("Parsing function name"): + if not isinstance(alist, list) or len(alist) == 0: + context.error("Invalid definition of function.", + syntax=SYNTAX_FUNCTION) + name = alist[0] + if not isinstance(name, str): + context.expected_word_error("Function name", name) + with context.layer(f"Parsing function '{name}'"): + arguments = parse_typed_list(context, alist[1:]) + if not isinstance(type_name, str): + context.expected_word_error("Function type", type_name) return pddl.Function(name, arguments, type_name) -def parse_condition(alist, type_dict, predicate_dict): - condition = parse_condition_aux(alist, False, type_dict, predicate_dict) - return condition.uniquify_variables({}).simplified() +def parse_condition(context, alist, type_dict, predicate_dict): + with context.layer("Parsing condition"): + condition = parse_condition_aux( + context, alist, False, type_dict, predicate_dict) + return condition.uniquify_variables({}).simplified() -def parse_condition_aux(alist, negated, type_dict, predicate_dict): +def parse_condition_aux(context, alist, negated, type_dict, predicate_dict): """Parse a PDDL condition. The condition is translated into NNF on the fly.""" + if not alist: + context.error("Expected a non-empty block as condition.") tag = alist[0] if tag in ("and", "or", "not", "imply"): args = alist[1:] if tag == "imply": - assert len(args) == 2 + if len(args) != 2: + context.error("'imply' expects exactly two arguments.", + syntax=SYNTAX_CONDITION_IMPLY) if tag == "not": - assert len(args) == 1 - return parse_condition_aux( - args[0], not negated, type_dict, predicate_dict) + if len(args) != 1: + context.error("'not' expects exactly one argument.", + syntax=SYNTAX_CONDITION_NOT) + negated = not negated elif tag in ("forall", "exists"): - parameters = parse_typed_list(alist[1]) - args = alist[2:] - assert len(args) == 1 + if len(alist) != 3: + context.error("'forall' and 'exists' expect exactly two arguments.", + syntax=SYNTAX_CONDITION_FORALL_EXISTS) + if not isinstance(alist[1], list) or not alist[1]: + context.error( + "The first argument (VARIABLES) of 'forall' and 'exists' is " + "expected to be a non-empty block.", + syntax=SYNTAX_CONDITION_FORALL_EXISTS + ) + parameters = parse_typed_list(context, alist[1]) + args = [alist[2]] + elif tag in predicate_dict: + return parse_literal(context, alist, type_dict, predicate_dict, negated=negated) else: - return parse_literal(alist, type_dict, predicate_dict, negated=negated) + context.error("Expected logical operator or predicate name", tag) + + for nb_arg, arg in enumerate(args, start=1): + if not isinstance(arg, list) or not arg: + context.error( + f"'{tag}' expects as {nb_arg}.
argument a non-empty block.", + item=arg, syntax=CONDITION_TAG_TO_SYNTAX[tag]) if tag == "imply": parts = [parse_condition_aux( - args[0], not negated, type_dict, predicate_dict), + context, args[0], not negated, type_dict, predicate_dict), parse_condition_aux( - args[1], negated, type_dict, predicate_dict)] + context, args[1], negated, type_dict, predicate_dict)] tag = "or" else: - parts = [parse_condition_aux(part, negated, type_dict, predicate_dict) + parts = [parse_condition_aux(context, part, negated, type_dict, predicate_dict) for part in args] if tag == "and" and not negated or tag == "or" and negated: @@ -95,36 +268,52 @@ def parse_condition_aux(alist, negated, type_dict, predicate_dict): return pddl.UniversalCondition(parameters, parts) elif tag == "exists" and not negated or tag == "forall" and negated: return pddl.ExistentialCondition(parameters, parts) - - -def parse_literal(alist, type_dict, predicate_dict, negated=False): - if alist[0] == "not": - assert len(alist) == 2 - alist = alist[1] - negated = not negated - - pred_id, arity = _get_predicate_id_and_arity( - alist[0], type_dict, predicate_dict) - - if arity != len(alist) - 1: - raise SystemExit("predicate used with wrong arity: (%s)" - % " ".join(alist)) - - if negated: - return pddl.NegatedAtom(pred_id, alist[1:]) - else: - return pddl.Atom(pred_id, alist[1:]) + elif tag == "not": + return parts[0] + + +def parse_literal(context, alist, type_dict, predicate_dict, negated=False): + with context.layer("Parsing literal"): + if not alist: + context.error("Literal definition has to be a non-empty block.", + alist, syntax=SYNTAX_LITERAL_POSSIBLY_NEGATED) + if alist[0] == "not": + if len(alist) != 2: + context.error( + "Negated literal definition has to have exactly one block as argument.", + alist, syntax=SYNTAX_LITERAL_NEGATED) + alist = alist[1] + if not isinstance(alist, list) or not alist: + context.error( + "Definition of negated literal has to be a non-empty block.", + alist, syntax=SYNTAX_LITERAL) + negated = not negated + + predicate_name = alist[0] + if not isinstance(predicate_name, str): + context.expected_word_error("Predicate name", predicate_name) + pred_id, arity = _get_predicate_id_and_arity( + context, predicate_name, type_dict, predicate_dict) + + if arity != len(alist) - 1: + context.error(f"Predicate '{predicate_name}' of arity {arity} used" + f" with {len(alist) -1} arguments.", alist) + + if negated: + return pddl.NegatedAtom(pred_id, alist[1:]) + else: + return pddl.Atom(pred_id, alist[1:]) SEEN_WARNING_TYPE_PREDICATE_NAME_CLASH = False -def _get_predicate_id_and_arity(text, type_dict, predicate_dict): +def _get_predicate_id_and_arity(context, text, type_dict, predicate_dict): global SEEN_WARNING_TYPE_PREDICATE_NAME_CLASH the_type = type_dict.get(text) the_predicate = predicate_dict.get(text) if the_type is None and the_predicate is None: - raise SystemExit("Undeclared predicate: %s" % text) + context.error("Undeclared predicate", text) elif the_predicate is not None: if the_type is not None and not SEEN_WARNING_TYPE_PREDICATE_NAME_CLASH: msg = ("Warning: name clash between type and predicate %r.\n" @@ -137,16 +326,18 @@ def _get_predicate_id_and_arity(text, type_dict, predicate_dict): return the_type.get_predicate_name(), 1 -def parse_effects(alist, result, type_dict, predicate_dict): +def parse_effects(context, alist, result, type_dict, predicate_dict): """Parse a PDDL effect (any combination of simple, conjunctive, conditional, and universal).""" - tmp_effect = parse_effect(alist, type_dict, 
predicate_dict) - normalized = tmp_effect.normalize() - cost_eff, rest_effect = normalized.extract_cost() - add_effect(rest_effect, result) - if cost_eff: - return cost_eff.effect - else: - return None + with context.layer("Parsing effect"): + tmp_effect = parse_effect(context, alist, type_dict, predicate_dict) + normalized = tmp_effect.normalize() + cost_eff, rest_effect = normalized.extract_cost() + add_effect(rest_effect, result) + if cost_eff: + return cost_eff.effect + else: + return None + def add_effect(tmp_effect, result): """tmp_effect has the following structure: @@ -188,93 +379,153 @@ def add_effect(tmp_effect, result): result.remove(contradiction) result.append(new_effect) -def parse_effect(alist, type_dict, predicate_dict): + +def parse_effect(context, alist, type_dict, predicate_dict): + if not alist: + context.error("All (sub-)effects have to be non-empty blocks.", alist) tag = alist[0] if tag == "and": - return pddl.ConjunctiveEffect( - [parse_effect(eff, type_dict, predicate_dict) for eff in alist[1:]]) + effects = [] + for eff in alist[1:]: + if not isinstance(eff, list): + context.error("All sub-effects of a conjunction have to be blocks.", + eff) + effects.append(parse_effect(context, eff, type_dict, predicate_dict)) + return pddl.ConjunctiveEffect(effects) elif tag == "forall": - assert len(alist) == 3 - parameters = parse_typed_list(alist[1]) - effect = parse_effect(alist[2], type_dict, predicate_dict) + if len(alist) != 3: + context.error("'forall' effect expects exactly two arguments.", + syntax=SYNTAX_EFFECT_FORALL) + if not isinstance(alist[1], list): + context.expected_list_error( + "First argument (VARIABLES) of 'forall'", + alist[1], syntax=SYNTAX_EFFECT_FORALL) + parameters = parse_typed_list(context, alist[1]) + if not isinstance(alist[2], list): + context.expected_list_error( + "Second argument (EFFECT) of 'forall'", + alist[2], syntax=SYNTAX_EFFECT_FORALL) + effect = parse_effect(context, alist[2], type_dict, predicate_dict) return pddl.UniversalEffect(parameters, effect) elif tag == "when": - assert len(alist) == 3 - condition = parse_condition( - alist[1], type_dict, predicate_dict) - effect = parse_effect(alist[2], type_dict, predicate_dict) + if len(alist) != 3: + context.error("'when' effect expects exactly two arguments.", + syntax=SYNTAX_EFFECT_WHEN) + if not isinstance(alist[1], list): + context.error( + "First argument (CONDITION) of 'when' is expected to be a block.", + alist[1], syntax=SYNTAX_EFFECT_WHEN) + condition = parse_condition(context, alist[1], type_dict, predicate_dict) + if not isinstance(alist[2], list): + context.expected_list_error( + "Second argument (EFFECT) of 'when'", + alist[2], syntax=SYNTAX_EFFECT_WHEN) + effect = parse_effect(context, alist[2], type_dict, predicate_dict) return pddl.ConditionalEffect(condition, effect) elif tag == "increase": - assert len(alist) == 3 - assert alist[1] == ['total-cost'] - assignment = parse_assignment(alist) + if len(alist) != 3 or alist[1] != ["total-cost"]: + context.error("'increase' expects exactly two arguments.", + alist, syntax=SYNTAX_EFFECT_INCREASE) + assignment = parse_assignment(context, alist) return pddl.CostEffect(assignment) else: # We pass in {} instead of type_dict here because types must # be static predicates, so cannot be the target of an effect.
- return pddl.SimpleEffect(parse_literal(alist, {}, predicate_dict)) - - -def parse_expression(exp): - if isinstance(exp, list): - functionsymbol = exp[0] - return pddl.PrimitiveNumericExpression(functionsymbol, exp[1:]) - elif exp.replace(".", "").isdigit(): - return pddl.NumericConstant(float(exp)) - elif exp[0] == "-": - raise ValueError("Negative numbers are not supported") - else: - return pddl.PrimitiveNumericExpression(exp, []) - -def parse_assignment(alist): - assert len(alist) == 3 - op = alist[0] - head = parse_expression(alist[1]) - exp = parse_expression(alist[2]) - if op == "=": - return pddl.Assign(head, exp) - elif op == "increase": - return pddl.Increase(head, exp) - else: - assert False, "Assignment operator not supported." - - -def parse_action(alist, type_dict, predicate_dict): - iterator = iter(alist) - action_tag = next(iterator) - assert action_tag == ":action" - name = next(iterator) - parameters_tag_opt = next(iterator) - if parameters_tag_opt == ":parameters": - parameters = parse_typed_list(next(iterator), - only_variables=True) - precondition_tag_opt = next(iterator) - else: - parameters = [] - precondition_tag_opt = parameters_tag_opt - if precondition_tag_opt == ":precondition": - precondition_list = next(iterator) - if not precondition_list: - # Note that :precondition () is allowed in PDDL. - precondition = pddl.Conjunction([]) + return pddl.SimpleEffect(parse_literal(context, alist, {}, predicate_dict)) + + +def parse_expression(context, exp): + with context.layer("Parsing expression"): + if isinstance(exp, list): + if len(exp) < 1: + context.error("Expression cannot be an empty block.", + syntax=SYNTAX_EXPRESSION) + functionsymbol = exp[0] + return pddl.PrimitiveNumericExpression(functionsymbol, exp[1:]) + elif exp.replace(".", "").isdigit() and exp.count(".") <= 1: + return pddl.NumericConstant(float(exp)) + elif exp[0] == "-": + context.error("Expression cannot be a negative number", + syntax=SYNTAX_EXPRESSION) else: - precondition = parse_condition( - precondition_list, type_dict, predicate_dict) - effect_tag = next(iterator) - else: - precondition = pddl.Conjunction([]) - effect_tag = precondition_tag_opt - assert effect_tag == ":effect" - effect_list = next(iterator) - eff = [] - if effect_list: + return pddl.PrimitiveNumericExpression(exp, []) + + +def parse_assignment(context, alist): + with context.layer("Parsing Assignment"): + if len(alist) != 3: + context.error("Assignment expects two arguments", + syntax=SYNTAX_ASSIGNMENT) + op = alist[0] + head = parse_expression(context, alist[1]) + exp = parse_expression(context, alist[2]) + if op == "=": + return pddl.Assign(head, exp) + elif op == "increase": + return pddl.Increase(head, exp) + else: + context.error(f"Unsupported assignment operator '{op}'." + f" Use '=' or 'increase'.") + + +def parse_action(context, alist, type_dict, predicate_dict): + with context.layer("Parsing action name"): + if len(alist) < 4: + context.error("Expecting block with at least 3 arguments for an action.", + syntax=SYNTAX_ACTION) + iterator = iter(alist) + action_tag = next(iterator) + assert action_tag == ":action" + name = next(iterator) + if not isinstance(name, str): + context.expected_word_error("Action name", name, syntax=SYNTAX_ACTION) + with context.layer(f"Parsing action '{name}'"): try: - cost = parse_effects( - effect_list, eff, type_dict, predicate_dict) - except ValueError as e: - raise SystemExit("Error in Action %s\nReason: %s." 
% (name, e)) - for rest in iterator: - assert False, rest + with context.layer("Parsing parameters"): + parameters_tag_opt = next(iterator) + if parameters_tag_opt == ":parameters": + parameters_list = next(iterator) + if not isinstance(parameters_list, list): + context.expected_list_error( + "Parameters", parameters_list, syntax=SYNTAX_ACTION) + parameters = parse_typed_list( + context, parameters_list, only_variables=True) + precondition_tag_opt = next(iterator) + else: + parameters = [] + precondition_tag_opt = parameters_tag_opt + with context.layer("Parsing precondition"): + if precondition_tag_opt == ":precondition": + precondition_list = next(iterator) + if not isinstance(precondition_list, list): + context.expected_list_error( + "Precondition", precondition_list, syntax=SYNTAX_ACTION) + if not precondition_list: + # Note that :precondition () is allowed in PDDL. + precondition = pddl.Conjunction([]) + else: + precondition = parse_condition( + context, precondition_list, type_dict, predicate_dict) + effect_tag = next(iterator) + else: + precondition = pddl.Conjunction([]) + effect_tag = precondition_tag_opt + with context.layer("Parsing effect"): + if effect_tag != ":effect": + context.error( + "Effect tag is expected to be ':effect'", effect_tag, + syntax=SYNTAX_ACTION) + effect_list = next(iterator) + if not isinstance(effect_list, list): + context.expected_list_error("Effect", effect_list, syntax=SYNTAX_ACTION) + eff = [] + if effect_list: + cost = parse_effects( + context, effect_list, eff, type_dict, predicate_dict) + except StopIteration: + context.error(f"Missing fields. Expecting {SYNTAX_ACTION}.") + for _ in iterator: + context.error(f"Too many fields. Expecting {SYNTAX_ACTION}") if eff: return pddl.Action(name, parameters, len(parameters), precondition, eff, cost) @@ -282,27 +533,117 @@ def parse_action(alist, type_dict, predicate_dict): return None -def parse_axiom(alist, type_dict, predicate_dict): - assert len(alist) == 3 - assert alist[0] == ":derived" - predicate = parse_predicate(alist[1]) - condition = parse_condition( - alist[2], type_dict, predicate_dict) - return pddl.Axiom(predicate.name, predicate.arguments, - len(predicate.arguments), condition) +def parse_axiom(context, alist, type_dict, predicate_dict): + with context.layer("Parsing derived predicate"): + if len(alist) != 3: + context.error("Expecting block with exactly three elements", + syntax=SYNTAX_AXIOM) + assert alist[0] == ":derived" + if not isinstance(alist[1], list): + context.expected_list_error("The first argument (PREDICATE)", + syntax=SYNTAX_AXIOM) + predicate = parse_predicate(context, alist[1]) + with context.layer(f"Parsing condition for derived predicate '{predicate}'"): + if not isinstance(alist[2], list): + context.error("The second argument (CONDITION) is expected to be a block.", + syntax=SYNTAX_AXIOM) + condition = parse_condition( + context, alist[2], type_dict, predicate_dict) + return pddl.Axiom(predicate.name, predicate.arguments, + len(predicate.arguments), condition) + +def parse_axioms_and_actions(context, entries, type_dict, predicate_dict): + the_axioms = [] + the_actions = [] + for no, entry in enumerate(entries, start=1): + with context.layer(f"Parsing {no}. axiom/action entry"): + assert_named_block(context, entry, [":derived", ":action"]) + if entry[0] == ":derived": + with context.layer(f"Parsing {len(the_axioms) + 1}. 
axiom"): + the_axioms.append(parse_axiom( + context, entry, type_dict, predicate_dict)) + else: + assert entry[0] == ":action" + with context.layer(f"Parsing {len(the_actions) + 1}. action"): + action = parse_action(context, entry, type_dict, predicate_dict) + if action is not None: + the_actions.append(action) + return the_axioms, the_actions + +def parse_init(context, alist): + initial = [] + initial_proposition_values = dict() + initial_assignments = dict() + for no, fact in enumerate(alist[1:], start=1): + with context.layer(f"Parsing {no}. element in init block"): + if not isinstance(fact, list) or not fact: + context.error( + "Invalid fact.", + syntax=f"{SYNTAX_LITERAL_POSSIBLY_NEGATED} or {SYNTAX_ASSIGNMENT}") + if fact[0] == "=": + try: + assignment = parse_assignment(context, fact) + except ValueError as e: + context.error(f"Error in initial state specification\n" + f"Reason: {e}.") + if not isinstance(assignment.expression, + pddl.NumericConstant): + context.error("Illegal assignment in initial state specification.", + assignment) + if assignment.fluent in initial_assignments: + prev = initial_assignments[assignment.fluent] + if assignment.expression == prev.expression: + print(f"Warning: {assignment} is specified twice " + f"in initial state specification") + else: + context.error("Error in initial state specification\n" + "Reason: conflicting assignment for " + f"{assignment.fluent}.") + else: + initial_assignments[assignment.fluent] = assignment + initial.append(assignment) + elif fact[0] == "not": + if len(fact) != 2: + context.error(f"Expecting {SYNTAX_LITERAL_NEGATED} for negated atoms.") + fact = fact[1] + if not isinstance(fact, list) or not fact: + context.error("Invalid negated fact.", syntax=SYNTAX_LITERAL_NEGATED) + atom = pddl.Atom(fact[0], fact[1:]) + check_atom_consistency(context, atom, + initial_proposition_values, False) + initial_proposition_values[atom] = False + else: + atom = pddl.Atom(fact[0], fact[1:]) + check_atom_consistency(context, atom, + initial_proposition_values, True) + initial_proposition_values[atom] = True + initial.extend(atom for atom, val in initial_proposition_values.items() + if val is True) + return initial -def parse_task(domain_pddl, task_pddl): - domain_name, domain_requirements, types, type_dict, constants, predicates, predicate_dict, functions, actions, axioms \ - = parse_domain_pddl(domain_pddl) - task_name, task_domain_name, task_requirements, objects, init, goal, use_metric = parse_task_pddl(task_pddl, type_dict, predicate_dict) - assert domain_name == task_domain_name +def parse_task(domain_pddl, task_pddl): + context = Context() + if not isinstance(domain_pddl, list): + context.error("Invalid definition of a PDDL domain.") + domain_name, domain_requirements, types, type_dict, constants, predicates, \ + predicate_dict, functions, actions, axioms = parse_domain_pddl(context, domain_pddl) + if not isinstance(task_pddl, list): + context.error("Invalid definition of a PDDL task.") + task_name, task_domain_name, task_requirements, objects, init, goal, \ + use_metric = parse_task_pddl(context, task_pddl, type_dict, predicate_dict) + + if domain_name != task_domain_name: + context.error(f"The domain name specified by the task " + f"({task_domain_name}) does not match the name specified " + f"by the domain file ({domain_name}).") requirements = pddl.Requirements(sorted(set( domain_requirements.requirements + task_requirements.requirements))) objects = constants + objects check_for_duplicates( + context, [o.name for o in objects], 
errmsg="error: duplicate object %r", finalmsg="please check :constants and :objects definitions") @@ -313,180 +654,168 @@ def parse_task(domain_pddl, task_pddl): predicates, functions, init, goal, actions, axioms, use_metric) -def parse_domain_pddl(domain_pddl): +def parse_domain_pddl(context, domain_pddl): iterator = iter(domain_pddl) - - define_tag = next(iterator) - assert define_tag == "define" - domain_line = next(iterator) - assert domain_line[0] == "domain" and len(domain_line) == 2 - yield domain_line[1] - - ## We allow an arbitrary order of the requirement, types, constants, - ## predicates and functions specification. The PDDL BNF is more strict on - ## this, so we print a warning if it is violated. - requirements = pddl.Requirements([":strips"]) - the_types = [pddl.Type("object")] - constants, the_predicates, the_functions = [], [], [] - correct_order = [":requirements", ":types", ":constants", ":predicates", - ":functions"] - seen_fields = [] - first_action = None - for opt in iterator: - field = opt[0] - if field not in correct_order: - first_action = opt - break - if field in seen_fields: - raise SystemExit("Error in domain specification\n" + - "Reason: two '%s' specifications." % field) - if (seen_fields and - correct_order.index(seen_fields[-1]) > correct_order.index(field)): - msg = "\nWarning: %s specification not allowed here (cf. PDDL BNF)" % field - print(msg, file=sys.stderr) - seen_fields.append(field) - if field == ":requirements": - requirements = pddl.Requirements(opt[1:]) - elif field == ":types": - the_types.extend(parse_typed_list( - opt[1:], constructor=pddl.Type)) - elif field == ":constants": - constants = parse_typed_list(opt[1:]) - elif field == ":predicates": - the_predicates = [parse_predicate(entry) - for entry in opt[1:]] - the_predicates += [pddl.Predicate("=", [ - pddl.TypedObject("?x", "object"), - pddl.TypedObject("?y", "object")])] - elif field == ":functions": - the_functions = parse_typed_list( - opt[1:], - constructor=parse_function, - default_type="number") - set_supertypes(the_types) - yield requirements - yield the_types - type_dict = {type.name: type for type in the_types} - yield type_dict - yield constants - yield the_predicates - predicate_dict = {pred.name: pred for pred in the_predicates} - yield predicate_dict - yield the_functions - - entries = [] - if first_action is not None: - entries.append(first_action) - entries.extend(iterator) - - the_axioms = [] - the_actions = [] - for entry in entries: - if entry[0] == ":derived": - axiom = parse_axiom(entry, type_dict, predicate_dict) - the_axioms.append(axiom) - else: - action = parse_action(entry, type_dict, predicate_dict) - if action is not None: - the_actions.append(action) - yield the_actions - yield the_axioms - -def parse_task_pddl(task_pddl, type_dict, predicate_dict): + with context.layer("Parsing domain"): + define_tag = next(iterator) + if define_tag != "define": + context.error(f"Domain definition expected to start with '(define '. Got '({define_tag}'") + + with context.layer("Parsing domain name"): + domain_line = next(iterator) + if (not check_named_block(domain_line, ["domain"]) or + len(domain_line) != 2 or not isinstance(domain_line[1], str)): + context.error("Invalid definition of domain name.", + syntax=SYNTAX_DOMAIN_DOMAIN_NAME) + yield domain_line[1] + + ## We allow an arbitrary order of the requirement, types, constants, + ## predicates and functions specification. The PDDL BNF is more strict on + ## this, so we print a warning if it is violated. 
+ requirements = pddl.Requirements([":strips"]) + the_types = [pddl.Type("object")] + constants, the_predicates, the_functions = [], [], [] + correct_order = [":requirements", ":types", ":constants", ":predicates", + ":functions"] + action_or_axiom_block = [":derived", ":action"] + seen_fields = [] + first_action = None + for opt in iterator: + assert_named_block(context, opt, correct_order + action_or_axiom_block) + field = opt[0] + if field not in correct_order: + first_action = opt + break + if field in seen_fields: + context.error(f"Error in domain specification\n" + f"Reason: two '{field}' specifications.") + if (seen_fields and + correct_order.index(seen_fields[-1]) > correct_order.index(field)): + msg = f"\nWarning: {field} specification not allowed here (cf. PDDL BNF)" + print(msg, file=sys.stderr) + seen_fields.append(field) + if field == ":requirements": + requirements = parse_requirements(context, opt[1:]) + elif field == ":types": + with context.layer("Parsing types"): + the_types.extend(parse_typed_list( + context, opt[1:], constructor=construct_type)) + elif field == ":constants": + with context.layer("Parsing constants"): + constants = parse_typed_list(context, opt[1:]) + elif field == ":predicates": + the_predicates = parse_predicates(context, opt[1:]) + the_predicates += [pddl.Predicate("=", [ + pddl.TypedObject("?x", "object"), + pddl.TypedObject("?y", "object")])] + elif field == ":functions": + with context.layer("Parsing functions"): + the_functions = parse_typed_list( + context, opt[1:], + constructor=parse_function, + default_type="number") + set_supertypes(the_types) + yield requirements + yield the_types + type_dict = {type.name: type for type in the_types} + yield type_dict + yield constants + yield the_predicates + predicate_dict = {pred.name: pred for pred in the_predicates} + yield predicate_dict + yield the_functions + + entries = [] + if first_action is not None: + entries.append(first_action) + entries.extend(iterator) + + the_axioms, the_actions = parse_axioms_and_actions( + context, entries, type_dict, predicate_dict) + + yield the_actions + yield the_axioms + +def parse_task_pddl(context, task_pddl, type_dict, predicate_dict): iterator = iter(task_pddl) - - define_tag = next(iterator) - assert define_tag == "define" - problem_line = next(iterator) - assert problem_line[0] == "problem" and len(problem_line) == 2 - yield problem_line[1] - domain_line = next(iterator) - assert domain_line[0] == ":domain" and len(domain_line) == 2 - yield domain_line[1] - - requirements_opt = next(iterator) - if requirements_opt[0] == ":requirements": - requirements = requirements_opt[1:] - objects_opt = next(iterator) - else: - requirements = [] - objects_opt = requirements_opt - yield pddl.Requirements(requirements) - - if objects_opt[0] == ":objects": - yield parse_typed_list(objects_opt[1:]) - init = next(iterator) - else: - yield [] - init = objects_opt - - assert init[0] == ":init" - initial = [] - initial_true = set() - initial_false = set() - initial_assignments = dict() - for fact in init[1:]: - if fact[0] == "=": - try: - assignment = parse_assignment(fact) - except ValueError as e: - raise SystemExit("Error in initial state specification\n" + - "Reason: %s." 
% e) - if not isinstance(assignment.expression, - pddl.NumericConstant): - raise SystemExit("Illegal assignment in initial state " + - "specification:\n%s" % assignment) - if assignment.fluent in initial_assignments: - prev = initial_assignments[assignment.fluent] - if assignment.expression == prev.expression: - print("Warning: %s is specified twice" % assignment, - "in initial state specification") - else: - raise SystemExit("Error in initial state specification\n" + - "Reason: conflicting assignment for " + - "%s." % assignment.fluent) - else: - initial_assignments[assignment.fluent] = assignment - initial.append(assignment) - elif fact[0] == "not": - atom = pddl.Atom(fact[1][0], fact[1][1:]) - check_atom_consistency(atom, initial_false, initial_true, False) - initial_false.add(atom) + with context.layer("Parsing task"): + define_tag = next(iterator) + if define_tag != "define": + context.error(f"Task definition expected to start with '(define '. Got '({define_tag}'") + + with context.layer("Parsing problem name"): + problem_line = next(iterator) + if (not check_named_block(problem_line, ["problem"]) or + len(problem_line) != 2 or not isinstance(problem_line[1], str)): + context.error("Invalid problem name definition.", problem_line, + syntax=SYNTAX_TASK_PROBLEM_NAME) + yield problem_line[1] + + with context.layer("Parsing domain name"): + domain_line = next(iterator) + if (not check_named_block(domain_line, [":domain"]) or + len(domain_line) != 2 or not isinstance(domain_line[1], str)): + context.error("Invalid domain name definition.", domain_line, + syntax=SYNTAX_TASK_DOMAIN_NAME) + yield domain_line[1] + + requirements_opt = next(iterator) + assert_named_block(context, requirements_opt, [":requirements", ":objects", ":init"]) + if requirements_opt[0] == ":requirements": + requirements = requirements_opt[1:] + objects_opt = next(iterator) else: - atom = pddl.Atom(fact[0], fact[1:]) - check_atom_consistency(atom, initial_true, initial_false) - initial_true.add(atom) - initial.extend(initial_true) - yield initial - - goal = next(iterator) - assert goal[0] == ":goal" and len(goal) == 2 - yield parse_condition(goal[1], type_dict, predicate_dict) - - use_metric = False - for entry in iterator: - if entry[0] == ":metric": - if entry[1] == "minimize" and entry[2][0] == "total-cost": - use_metric = True - else: - assert False, "Unknown metric." - yield use_metric - - for entry in iterator: - assert False, entry - - -def check_atom_consistency(atom, same_truth_value, other_truth_value, atom_is_true=True): - if atom in other_truth_value: - raise SystemExit("Error in initial state specification\n" + - "Reason: %s is true and false."
% atom) - if atom in same_truth_value: - if not atom_is_true: - atom = atom.negate() - print("Warning: %s is specified twice in initial state specification" % atom) + requirements = [] + objects_opt = requirements_opt + yield parse_requirements(context, requirements) + + assert_named_block(context, objects_opt, [":objects", ":init"]) + if objects_opt[0] == ":objects": + with context.layer("Parsing objects"): + yield parse_typed_list(context, objects_opt[1:]) + init = next(iterator) + else: + yield [] + init = objects_opt + + assert_named_block(context, init, [":init"]) + yield parse_init(context, init) + + goal = next(iterator) + with context.layer("Parsing goal"): + if (not check_named_block(goal, [":goal"]) or + len(goal) != 2 or not isinstance(goal[1], list) or + not goal[1]): + context.error("Expected non-empty goal.", syntax=SYNTAX_GOAL) + yield parse_condition(context, goal[1], type_dict, predicate_dict) + + use_metric = False + for entry in iterator: + if isinstance(entry, list) and entry[0] == ":metric": + with context.layer("Parsing metric"): + if len(entry) != 3 or not isinstance(entry[2], list) or len(entry[2]) != 1 or entry[1] != "minimize" or entry[2][0] != "total-cost": + context.error("Invalid metric definition.", entry, syntax=SYNTAX_METRIC) + use_metric = True + yield use_metric + + for _ in iterator: + assert False, "This line should be unreachable" + + +def check_atom_consistency(context, atom, initial_proposition_values, + atom_value): + if atom in initial_proposition_values: + prev_value = initial_proposition_values[atom] + if prev_value != atom_value: + context.error(f"Error in initial state specification\n" + f"Reason: {atom} is true and false.") + else: + if atom_value is False: + atom = atom.negate() + print(f"Warning: {atom} is specified twice in initial state specification") -def check_for_duplicates(elements, errmsg, finalmsg): +def check_for_duplicates(context, elements, errmsg, finalmsg): seen = set() errors = [] for element in elements: @@ -495,4 +824,4 @@ def check_for_duplicates(elements, errmsg, finalmsg): else: seen.add(element) if errors: - raise SystemExit("\n".join(errors) + "\n" + finalmsg) + context.error("\n".join(errors) + "\n" + finalmsg) diff --git a/src/translate/pddl_parser/pddl_file.py b/src/translate/pddl_parser/pddl_file.py index 58850e402..9e8fc01f3 100644 --- a/src/translate/pddl_parser/pddl_file.py +++ b/src/translate/pddl_parser/pddl_file.py @@ -1,4 +1,5 @@ from . import lisp_parser +from . import parse_error from . import parsing_functions file_open = open @@ -14,10 +15,10 @@ def parse_pddl_file(type, filename): return lisp_parser.parse_nested_list(file_open(filename, encoding='ISO-8859-1')) except OSError as e: - raise SystemExit("Error: Could not read file: %s\nReason: %s." % + raise SystemExit("Error: Could not read file: %s\nReason: %s" % (e.filename, e)) - except lisp_parser.ParseError as e: - raise SystemExit("Error: Could not parse %s file: %s\nReason: %s." 
% + except parse_error.ParseError as e: + raise parse_error.ParseError("Error: Could not parse %s file: %s\nReason: %s" % (type, filename, e)) diff --git a/src/translate/pddl_to_prolog.py b/src/translate/pddl_to_prolog.py index fee70f7c3..7950451be 100755 --- a/src/translate/pddl_to_prolog.py +++ b/src/translate/pddl_to_prolog.py @@ -155,6 +155,10 @@ def translate_facts(prog, task): assert isinstance(fact, pddl.Atom) or isinstance(fact, pddl.Assign) if isinstance(fact, pddl.Atom): prog.add_fact(fact) + else: + # Add a fact to indicate that the primitive numeric expression in + # fact.fluent has been defined. + prog.add_fact(normalize.get_pne_definition_predicate(fact.fluent)) def translate(task): # Note: The function requires that the task has been normalized. diff --git a/src/translate/tools.py b/src/translate/tools.py index ad244b8e9..12ac84a9b 100644 --- a/src/translate/tools.py +++ b/src/translate/tools.py @@ -1,22 +1,3 @@ -def cartesian_product(sequences): - # TODO: Rename this. It's not good that we have two functions - # called "product" and "cartesian_product", of which "product" - # computes cartesian products, while "cartesian_product" does not. - - # This isn't actually a proper cartesian product because we - # concatenate lists, rather than forming sequences of atomic elements. - # We could probably also use something like - # map(itertools.chain, product(*sequences)) - # but that does not produce the same results - if not sequences: - yield [] - else: - temp = list(cartesian_product(sequences[1:])) - for item in sequences[0]: - for sequence in temp: - yield item + sequence - - def get_peak_memory_in_kb(): try: # This will only work on Linux systems. diff --git a/src/translate/translate.py b/src/translate/translate.py index 34623ce3d..58c5dd08d 100755 --- a/src/translate/translate.py +++ b/src/translate/translate.py @@ -50,6 +50,7 @@ def python_version_supported(): ## we only list codes that are used by the translator component of the planner. TRANSLATE_OUT_OF_MEMORY = 20 TRANSLATE_OUT_OF_TIME = 21 +TRANSLATE_INPUT_ERROR = 31 simplified_effect_condition_counter = 0 added_implied_precondition_counter = 0 @@ -584,12 +585,15 @@ def pddl_to_sas(task): elif goal_list is None: return unsolvable_sas_task("Trivially false goal") + negative_in_goal = set() for item in goal_list: assert isinstance(item, pddl.Literal) + if item.negated: + negative_in_goal.add(item.negate()) with timers.timing("Computing fact groups", block=True): groups, mutex_groups, translation_key = fact_groups.compute_groups( - task, atoms, reachable_action_params) + task, atoms, reachable_action_params, negative_in_goal) with timers.timing("Building STRIPS to SAS dictionary"): ranges, strips_to_sas = strips_to_sas_dictionary( @@ -800,3 +804,6 @@ def handle_sigxcpu(signum, stackframe): traceback.print_exc(file=sys.stdout) print("=" * 79) sys.exit(TRANSLATE_OUT_OF_MEMORY) + except pddl_parser.ParseError as e: + print(e) + sys.exit(TRANSLATE_INPUT_ERROR)