Skip to content

Commit

Permalink
refactor policy code by making use of Base, added BaseVisitor for con…
Browse files Browse the repository at this point in the history
…ditions and effects
  • Loading branch information
drexlerd committed Nov 23, 2023
1 parent dd0f2dd commit a0387b6
Show file tree
Hide file tree
Showing 20 changed files with 402 additions and 908 deletions.
20 changes: 4 additions & 16 deletions api/python/src/dlplan/policy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,59 +5,48 @@ from ..core import State, DenotationsCaches, Boolean, Numerical, Concept, Role,


class NamedBoolean():
def __repr__(self) -> str: ...
def __str__(self) -> str: ...
def get_key(self) -> str: ...
def get_boolean(self) -> Boolean: ...
def get_element(self) -> Boolean: ...


class NamedNumerical():
def __repr__(self) -> str: ...
def __str__(self) -> str: ...
def get_key(self) -> str: ...
def get_numerical(self) -> Numerical: ...
def get_element(self) -> Numerical: ...


class NamedConcept():
def __repr__(self) -> str: ...
def __str__(self) -> str: ...
def get_key(self) -> str: ...
def get_concept(self) -> Concept: ...
def get_element(self) -> Concept: ...


class NamedRole():
def __repr__(self) -> str: ...
def __str__(self) -> str: ...
def get_key(self) -> str: ...
def get_role(self) -> Role: ...
def get_element(self) -> Role: ...


class BaseCondition:
def __repr__(self) -> str: ...
def __str__(self) -> str: ...
@overload
def evaluate(self, state: State) -> bool: ...
@overload
def evaluate(self, state: State, caches: DenotationsCaches) -> bool: ...
def get_index(self) -> int: ...
def get_boolean(self) -> Union[None, Boolean]: ...
def get_numerical(self) -> Union[None, Numerical]: ...


class BaseEffect:
def __repr__(self) -> str: ...
def __str__(self) -> str: ...
@overload
def evaluate(self, source_state: State, target_state: State) -> bool: ...
@overload
def evaluate(self, source_state: State, target_state: State, caches: DenotationsCaches) -> bool: ...
def get_index(self) -> int: ...
def get_boolean(self) -> Union[None, Boolean]: ...
def get_numerical(self) -> Union[None, Numerical]: ...


class Rule:
def __repr__(self) -> str: ...
def __str__(self) -> str: ...
@overload
def evaluate_conditions(self, state: State) -> bool: ...
Expand All @@ -73,7 +62,6 @@ class Rule:


class Policy:
def __repr__(self) -> str: ...
def __str__(self) -> str: ...
@overload
def evaluate(self, source_state: State, target_state: State) -> Union[None, Rule]: ...
Expand Down
36 changes: 12 additions & 24 deletions api/python/src/policy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,56 +16,45 @@ using namespace dlplan;

void init_policy(py::module_ &m_policy) {
py::class_<policy::NamedBoolean, std::shared_ptr<policy::NamedBoolean>>(m_policy, "NamedBoolean")
.def("__repr__", &policy::NamedBoolean::compute_repr)
.def("__str__", &policy::NamedBoolean::str)
.def("__str__", py::overload_cast<>(&policy::NamedBoolean::str, py::const_))
.def("get_key", &policy::NamedBoolean::get_key)
.def("get_boolean", &policy::NamedBoolean::get_boolean)
.def("get_element", &policy::NamedBoolean::get_element)
;

py::class_<policy::NamedNumerical, std::shared_ptr<policy::NamedNumerical>>(m_policy, "NamedNumerical")
.def("__repr__", &policy::NamedNumerical::compute_repr)
.def("__str__", &policy::NamedNumerical::str)
.def("__str__", py::overload_cast<>(&policy::NamedNumerical::str, py::const_))
.def("get_key", &policy::NamedNumerical::get_key)
.def("get_numerical", &policy::NamedNumerical::get_numerical)
.def("get_element", &policy::NamedNumerical::get_element)
;

py::class_<policy::NamedConcept, std::shared_ptr<policy::NamedConcept>>(m_policy, "NamedConcept")
.def("__repr__", &policy::NamedConcept::compute_repr)
.def("__str__", &policy::NamedConcept::str)
.def("__str__", py::overload_cast<>(&policy::NamedConcept::str, py::const_))
.def("get_key", &policy::NamedConcept::get_key)
.def("get_concept", &policy::NamedConcept::get_concept)
.def("get_element", &policy::NamedConcept::get_element)
;

py::class_<policy::NamedRole, std::shared_ptr<policy::NamedRole>>(m_policy, "NamedRole")
.def("__repr__", &policy::NamedRole::compute_repr)
.def("__str__", &policy::NamedRole::str)
.def("__str__", py::overload_cast<>(&policy::NamedRole::str, py::const_))
.def("get_key", &policy::NamedRole::get_key)
.def("get_role", &policy::NamedRole::get_role)
.def("get_element", &policy::NamedRole::get_element)
;

py::class_<policy::BaseCondition, std::shared_ptr<policy::BaseCondition>>(m_policy, "BaseCondition")
.def("__repr__", &policy::BaseCondition::compute_repr)
.def("__str__", &policy::BaseCondition::str)
.def("__str__", py::overload_cast<>(&policy::BaseCondition::str, py::const_))
.def("evaluate", py::overload_cast<const core::State&>(&policy::BaseCondition::evaluate, py::const_))
.def("evaluate", py::overload_cast<const core::State&, core::DenotationsCaches&>(&policy::BaseCondition::evaluate, py::const_))
.def("get_index", &policy::BaseCondition::get_index)
.def("get_boolean", &policy::BaseCondition::get_boolean)
.def("get_numerical", &policy::BaseCondition::get_numerical)
;

py::class_<policy::BaseEffect, std::shared_ptr<policy::BaseEffect>>(m_policy, "BaseEffect")
.def("__repr__", &policy::BaseEffect::compute_repr)
.def("__str__", &policy::BaseEffect::str)
.def("__str__", py::overload_cast<>(&policy::BaseEffect::str, py::const_))
.def("evaluate", py::overload_cast<const core::State&, const core::State&>(&policy::BaseEffect::evaluate, py::const_))
.def("evaluate", py::overload_cast<const core::State&, const core::State&, core::DenotationsCaches&>(&policy::BaseEffect::evaluate, py::const_))
.def("get_index", &policy::BaseEffect::get_index)
.def("get_boolean", &policy::BaseEffect::get_boolean)
.def("get_numerical", &policy::BaseEffect::get_numerical)
;

py::class_<policy::Rule, std::shared_ptr<policy::Rule>>(m_policy, "Rule")
.def("__repr__", &policy::Rule::compute_repr)
.def("__str__", &policy::Rule::str)
.def("__str__", py::overload_cast<>(&policy::Rule::str, py::const_))
.def("evaluate_conditions", py::overload_cast<const core::State&>(&policy::Rule::evaluate_conditions, py::const_))
.def("evaluate_conditions", py::overload_cast<const core::State&, core::DenotationsCaches&>(&policy::Rule::evaluate_conditions, py::const_))
.def("evaluate_effects", py::overload_cast<const core::State&, const core::State&>(&policy::Rule::evaluate_effects, py::const_))
Expand All @@ -76,8 +65,7 @@ void init_policy(py::module_ &m_policy) {
;

py::class_<policy::Policy, std::shared_ptr<policy::Policy>>(m_policy, "Policy")
.def("__repr__", &policy::Policy::compute_repr)
.def("__str__", &policy::Policy::str)
.def("__str__", py::overload_cast<>(&policy::Policy::str, py::const_))
.def("evaluate", py::overload_cast<const core::State&, const core::State&>(&policy::Policy::evaluate, py::const_))
.def("evaluate", py::overload_cast<const core::State&, const core::State&, core::DenotationsCaches&>(&policy::Policy::evaluate, py::const_))
.def("evaluate_conditions", py::overload_cast<const core::State&>(&policy::Policy::evaluate_conditions, py::const_))
Expand Down
2 changes: 0 additions & 2 deletions examples/policy/policy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ int main() {
assert(!policy->evaluate(state_2, state_0, caches));
assert(!policy->evaluate(state_2, state_1, caches));

std::cout << policy->compute_repr() << std::endl << std::endl;
std::cout << policy->str() << std::endl << std::endl;

std::cout << "Parsing policy:" << std::endl;
Expand All @@ -110,7 +109,6 @@ int main() {
"(:rule (:conditions (:c_b_pos b0) (:c_n_gt n0)) (:effects (:e_b_bot b0) (:e_n_dec n0)))\n"
")";
auto policy_in = policy_factory.parse_policy(policy_str);
std::cout << policy_in->compute_repr() << std::endl << std::endl;
std::cout << policy_in->str() << std::endl << std::endl;

return 0;
Expand Down
2 changes: 1 addition & 1 deletion include/dlplan/common/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class Base {

/// @brief Compute a string representation of this object.
void str(std::stringstream& out) const {
return static_cast<const Derived*>(this)->str_impl(out);
static_cast<const Derived*>(this)->str_impl(out);
}

/// @brief Compute a string representation of this object.
Expand Down
Loading

0 comments on commit a0387b6

Please sign in to comment.