Skip to content

Commit

Permalink
fix novelty table resize
Browse files Browse the repository at this point in the history
  • Loading branch information
drexlerd committed Oct 20, 2023
1 parent e65710a commit bb52c02
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 5 deletions.
5 changes: 3 additions & 2 deletions api/python/src/dlplan/novelty/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ from ..state_space import StateSpace

class NoveltyBase:
def __init__(self, num_atoms: int, arity: int) -> None: ...
def atom_tuple_to_tuple_index(self, tuple_atom_indices: List[int]) -> int: ...
def tuple_index_to_atom_tuple(self, tuple_index: int) -> List[int]: ...
def atom_indices_to_tuple_index(self, atom_indices: List[int]) -> int: ...
def tuple_index_to_atom_indices(self, tuple_index: int) -> List[int]: ...
def get_num_atoms(self) -> int: ...
def get_arity(self) -> int: ...

Expand All @@ -24,6 +24,7 @@ class NoveltyTable:
@overload
def insert_tuple_indices(self, tuple_indices: List[int], stop_if_novel: bool = False) -> bool: ...
def resize(self, novelty_base: NoveltyBase) -> None: ...
def get_novelty_base(self) -> NoveltyBase: ...


class TupleNode:
Expand Down
1 change: 1 addition & 0 deletions api/python/src/novelty.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ void init_novelty(py::module_ &m_novelty) {
.def("insert_atom_indices", py::overload_cast<const AtomIndices&, const AtomIndices&, bool>(&NoveltyTable::insert_atom_indices), py::arg("atom_indices"), py::arg("add_atom_indices"), py::arg("stop_if_novel") = false)
.def("insert_tuple_indices", py::overload_cast<const TupleIndices&, bool>(&NoveltyTable::insert_tuple_indices), py::arg("tuple_indices"), py::arg("stop_if_novel") = false)
.def("resize", &NoveltyTable::resize)
.def("get_novelty_base", &NoveltyTable::get_novelty_base)
;

py::class_<TupleNode, std::shared_ptr<TupleNode>>(m_novelty, "TupleNode")
Expand Down
2 changes: 2 additions & 0 deletions include/dlplan/novelty.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ class NoveltyTable

/// @brief Resizes the novelty table.
void resize(std::shared_ptr<const NoveltyBase> novelty_base);

const std::shared_ptr<const NoveltyBase> get_novelty_base() const;
};


Expand Down
2 changes: 1 addition & 1 deletion src/novelty/novelty_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <boost/serialization/vector.hpp>

#include "src/utils/math.h"
#include "src/utils/logging.h"


namespace dlplan::novelty {
Expand Down Expand Up @@ -38,7 +39,6 @@ NoveltyBase& NoveltyBase::operator=(NoveltyBase&& other) = default;
NoveltyBase::~NoveltyBase() = default;

TupleIndex NoveltyBase::atom_indices_to_tuple_index(const AtomIndices& atom_indices) const {
assert(static_cast<int>(atom_indices.size()) == m_arity);
assert(std::is_sorted(atom_indices.begin(), atom_indices.end()));
TupleIndex result = 0;
int i = 0;
Expand Down
19 changes: 17 additions & 2 deletions src/novelty/novelty_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,23 @@ void NoveltyTable::resize(std::shared_ptr<const NoveltyBase> novelty_base) {
if (novelty_base->get_arity() != m_novelty_base->get_arity()) {
throw std::runtime_error("NoveltyTable::resize - missmatched arity of novelty_table and novelty_base.");
}
m_table.resize(std::pow(novelty_base->get_num_atoms()+1, novelty_base->get_arity()), true);
m_novelty_base = novelty_base;
NoveltyTable new_table(novelty_base);
// mark tuples in new table
AtomIndices atom_indices;
int new_tuple_index;
for (int old_tuple_index = 0; old_tuple_index < m_table.size(); ++old_tuple_index) {
if (!m_table[old_tuple_index]) {
atom_indices = m_novelty_base->tuple_index_to_atom_indices(old_tuple_index);
new_tuple_index = novelty_base->atom_indices_to_tuple_index(atom_indices);
new_table.m_table[new_tuple_index] = false;
}
}
m_table = std::move(new_table.m_table);
m_novelty_base = std::move(novelty_base);
}

const std::shared_ptr<const NoveltyBase> NoveltyTable::get_novelty_base() const {
return m_novelty_base;
}

}

0 comments on commit bb52c02

Please sign in to comment.