From bb52c02f9d836a2b497e109848ec78730c726e1b Mon Sep 17 00:00:00 2001 From: Dominik Drexler Date: Fri, 20 Oct 2023 13:39:49 +0200 Subject: [PATCH] fix novelty table resize --- api/python/src/dlplan/novelty/__init__.pyi | 5 +++-- api/python/src/novelty.cpp | 1 + include/dlplan/novelty.h | 2 ++ src/novelty/novelty_base.cpp | 2 +- src/novelty/novelty_table.cpp | 19 +++++++++++++++++-- 5 files changed, 24 insertions(+), 5 deletions(-) diff --git a/api/python/src/dlplan/novelty/__init__.pyi b/api/python/src/dlplan/novelty/__init__.pyi index fa835f95..249ec78d 100644 --- a/api/python/src/dlplan/novelty/__init__.pyi +++ b/api/python/src/dlplan/novelty/__init__.pyi @@ -5,8 +5,8 @@ from ..state_space import StateSpace class NoveltyBase: def __init__(self, num_atoms: int, arity: int) -> None: ... - def atom_tuple_to_tuple_index(self, tuple_atom_indices: List[int]) -> int: ... - def tuple_index_to_atom_tuple(self, tuple_index: int) -> List[int]: ... + def atom_indices_to_tuple_index(self, atom_indices: List[int]) -> int: ... + def tuple_index_to_atom_indices(self, tuple_index: int) -> List[int]: ... def get_num_atoms(self) -> int: ... def get_arity(self) -> int: ... @@ -24,6 +24,7 @@ class NoveltyTable: @overload def insert_tuple_indices(self, tuple_indices: List[int], stop_if_novel: bool = False) -> bool: ... def resize(self, novelty_base: NoveltyBase) -> None: ... + def get_novelty_base(self) -> NoveltyBase: ... class TupleNode: diff --git a/api/python/src/novelty.cpp b/api/python/src/novelty.cpp index e8772c32..ae0483bf 100644 --- a/api/python/src/novelty.cpp +++ b/api/python/src/novelty.cpp @@ -29,6 +29,7 @@ void init_novelty(py::module_ &m_novelty) { .def("insert_atom_indices", py::overload_cast(&NoveltyTable::insert_atom_indices), py::arg("atom_indices"), py::arg("add_atom_indices"), py::arg("stop_if_novel") = false) .def("insert_tuple_indices", py::overload_cast(&NoveltyTable::insert_tuple_indices), py::arg("tuple_indices"), py::arg("stop_if_novel") = false) .def("resize", &NoveltyTable::resize) + .def("get_novelty_base", &NoveltyTable::get_novelty_base) ; py::class_>(m_novelty, "TupleNode") diff --git a/include/dlplan/novelty.h b/include/dlplan/novelty.h index 9f5d7707..44f5b743 100644 --- a/include/dlplan/novelty.h +++ b/include/dlplan/novelty.h @@ -164,6 +164,8 @@ class NoveltyTable /// @brief Resizes the novelty table. void resize(std::shared_ptr novelty_base); + + const std::shared_ptr get_novelty_base() const; }; diff --git a/src/novelty/novelty_base.cpp b/src/novelty/novelty_base.cpp index 532526d6..a02a99a0 100644 --- a/src/novelty/novelty_base.cpp +++ b/src/novelty/novelty_base.cpp @@ -11,6 +11,7 @@ #include #include "src/utils/math.h" +#include "src/utils/logging.h" namespace dlplan::novelty { @@ -38,7 +39,6 @@ NoveltyBase& NoveltyBase::operator=(NoveltyBase&& other) = default; NoveltyBase::~NoveltyBase() = default; TupleIndex NoveltyBase::atom_indices_to_tuple_index(const AtomIndices& atom_indices) const { - assert(static_cast(atom_indices.size()) == m_arity); assert(std::is_sorted(atom_indices.begin(), atom_indices.end())); TupleIndex result = 0; int i = 0; diff --git a/src/novelty/novelty_table.cpp b/src/novelty/novelty_table.cpp index 760fd3a0..2e744893 100644 --- a/src/novelty/novelty_table.cpp +++ b/src/novelty/novelty_table.cpp @@ -115,8 +115,23 @@ void NoveltyTable::resize(std::shared_ptr novelty_base) { if (novelty_base->get_arity() != m_novelty_base->get_arity()) { throw std::runtime_error("NoveltyTable::resize - missmatched arity of novelty_table and novelty_base."); } - m_table.resize(std::pow(novelty_base->get_num_atoms()+1, novelty_base->get_arity()), true); - m_novelty_base = novelty_base; + NoveltyTable new_table(novelty_base); + // mark tuples in new table + AtomIndices atom_indices; + int new_tuple_index; + for (int old_tuple_index = 0; old_tuple_index < m_table.size(); ++old_tuple_index) { + if (!m_table[old_tuple_index]) { + atom_indices = m_novelty_base->tuple_index_to_atom_indices(old_tuple_index); + new_tuple_index = novelty_base->atom_indices_to_tuple_index(atom_indices); + new_table.m_table[new_tuple_index] = false; + } + } + m_table = std::move(new_table.m_table); + m_novelty_base = std::move(novelty_base); +} + +const std::shared_ptr NoveltyTable::get_novelty_base() const { + return m_novelty_base; } } \ No newline at end of file