Skip to content

Commit

Permalink
[main] make RandomNumberGenerator an optional argument for VariableOr…
Browse files Browse the repository at this point in the history
…derFinder
  • Loading branch information
Silvan Sievers committed Aug 6, 2021
1 parent 06bbdfc commit 8eba9ac
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 35 deletions.
4 changes: 2 additions & 2 deletions src/search/merge_and_shrink/merge_tree_factory_linear.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ MergeTreeFactoryLinear::MergeTreeFactoryLinear(const options::Options &options)

unique_ptr<MergeTree> MergeTreeFactoryLinear::compute_merge_tree(
const TaskProxy &task_proxy) {
variable_order_finder::VariableOrderFinder vof(task_proxy, variable_order_type, *rng);
variable_order_finder::VariableOrderFinder vof(task_proxy, variable_order_type, rng);
MergeTreeNode *root = new MergeTreeNode(vof.next());
while (!vof.done()) {
MergeTreeNode *right_child = new MergeTreeNode(vof.next());
Expand Down Expand Up @@ -72,7 +72,7 @@ unique_ptr<MergeTree> MergeTreeFactoryLinear::compute_merge_tree(
skipping all indices not in indices_subset, because these have been set
to "used" above.
*/
variable_order_finder::VariableOrderFinder vof(task_proxy, variable_order_type, *rng);
variable_order_finder::VariableOrderFinder vof(task_proxy, variable_order_type, rng);

int next_var = vof.next();
int ts_index = var_to_ts_index[next_var];
Expand Down
8 changes: 2 additions & 6 deletions src/search/pdbs/pattern_collection_generator_combo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
#include "../task_proxy.h"

#include "../utils/logging.h"
#include "../utils/rng.h"
#include "../utils/rng_options.h"
#include "../utils/timer.h"

#include <iostream>
Expand All @@ -21,8 +19,7 @@ using namespace std;

namespace pdbs {
PatternCollectionGeneratorCombo::PatternCollectionGeneratorCombo(const Options &opts)
: max_states(opts.get<int>("max_states")),
rng(utils::parse_rng_from_options(opts)) {
: max_states(opts.get<int>("max_states")) {
}

PatternCollectionInformation PatternCollectionGeneratorCombo::generate(
Expand All @@ -32,7 +29,7 @@ PatternCollectionInformation PatternCollectionGeneratorCombo::generate(
TaskProxy task_proxy(*task);
shared_ptr<PatternCollection> patterns = make_shared<PatternCollection>();

PatternGeneratorGreedy large_pattern_generator(max_states, rng);
PatternGeneratorGreedy large_pattern_generator(max_states);
Pattern large_pattern = large_pattern_generator.generate(task).get_pattern();
set<int> used_vars(large_pattern.begin(), large_pattern.end());
patterns->push_back(move(large_pattern));
Expand All @@ -56,7 +53,6 @@ static shared_ptr<PatternCollectionGenerator> _parse(OptionParser &parser) {
"maximum abstraction size for combo strategy",
"1000000",
Bounds("1", "infinity"));
utils::add_rng_options(parser);

Options opts = parser.parse();
if (parser.dry_run())
Expand Down
5 changes: 0 additions & 5 deletions src/search/pdbs/pattern_collection_generator_combo.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,11 @@

#include "pattern_generator.h"

namespace utils {
class RandomNumberGenerator;
}

namespace pdbs {
/* Take one large pattern and then single-variable patterns for
all goal variables that are not in the large pattern. */
class PatternCollectionGeneratorCombo : public PatternCollectionGenerator {
int max_states;
std::shared_ptr<utils::RandomNumberGenerator> rng;
public:
explicit PatternCollectionGeneratorCombo(const options::Options &opts);
virtual ~PatternCollectionGeneratorCombo() = default;
Expand Down
14 changes: 4 additions & 10 deletions src/search/pdbs/pattern_generator_greedy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
#include "../task_utils/variable_order_finder.h"
#include "../utils/logging.h"
#include "../utils/math.h"
#include "../utils/rng.h"
#include "../utils/rng_options.h"
#include "../utils/timer.h"

#include <iostream>
Expand All @@ -20,14 +18,11 @@ using namespace std;

namespace pdbs {
PatternGeneratorGreedy::PatternGeneratorGreedy(const Options &opts)
: PatternGeneratorGreedy(
opts.get<int>("max_states"),
utils::parse_rng_from_options(opts)) {
: PatternGeneratorGreedy(opts.get<int>("max_states")) {
}

PatternGeneratorGreedy::PatternGeneratorGreedy(
int max_states, const shared_ptr<utils::RandomNumberGenerator> &rng)
: max_states(max_states), rng(rng) {
PatternGeneratorGreedy::PatternGeneratorGreedy(int max_states)
: max_states(max_states) {
}

PatternInformation PatternGeneratorGreedy::generate(const shared_ptr<AbstractTask> &task) {
Expand All @@ -36,7 +31,7 @@ PatternInformation PatternGeneratorGreedy::generate(const shared_ptr<AbstractTas
TaskProxy task_proxy(*task);
Pattern pattern;
variable_order_finder::VariableOrderFinder order(
task_proxy, variable_order_finder::GOAL_CG_LEVEL, *rng);
task_proxy, variable_order_finder::GOAL_CG_LEVEL);
VariablesProxy variables = task_proxy.get_variables();

int size = 1;
Expand Down Expand Up @@ -66,7 +61,6 @@ static shared_ptr<PatternGenerator> _parse(OptionParser &parser) {
"maximal number of abstract states in the pattern database.",
"1000000",
Bounds("1", "infinity"));
utils::add_rng_options(parser);

Options opts = parser.parse();
if (parser.dry_run())
Expand Down
8 changes: 1 addition & 7 deletions src/search/pdbs/pattern_generator_greedy.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,12 @@ namespace options {
class Options;
}

namespace utils {
class RandomNumberGenerator;
}

namespace pdbs {
class PatternGeneratorGreedy : public PatternGenerator {
int max_states;
std::shared_ptr<utils::RandomNumberGenerator> rng;
public:
explicit PatternGeneratorGreedy(const options::Options &opts);
explicit PatternGeneratorGreedy(
int max_states, const std::shared_ptr<utils::RandomNumberGenerator> &rng);
explicit PatternGeneratorGreedy(int max_states);
virtual ~PatternGeneratorGreedy() = default;

virtual PatternInformation generate(const std::shared_ptr<AbstractTask> &task) override;
Expand Down
9 changes: 7 additions & 2 deletions src/search/task_utils/variable_order_finder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ using utils::ExitCode;
namespace variable_order_finder {
VariableOrderFinder::VariableOrderFinder(const TaskProxy &task_proxy,
VariableOrderType variable_order_type,
utils::RandomNumberGenerator &rng)
shared_ptr<utils::RandomNumberGenerator> rng)
: task_proxy(task_proxy),
variable_order_type(variable_order_type) {
int var_count = task_proxy.get_variables().size();
Expand All @@ -31,7 +31,12 @@ VariableOrderFinder::VariableOrderFinder(const TaskProxy &task_proxy,

if (variable_order_type == CG_GOAL_RANDOM ||
variable_order_type == RANDOM) {
rng.shuffle(remaining_vars);
if (!rng) {
ABORT("No random number generator passed to VariableOrderFinder "
"although the chosen value for VariableOrderType relies on "
"randomization");
}
rng->shuffle(remaining_vars);
}

is_causal_predecessor.resize(var_count, false);
Expand Down
7 changes: 4 additions & 3 deletions src/search/task_utils/variable_order_finder.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ class VariableOrderFinder {

void select_next(int position, int var_no);
public:
VariableOrderFinder(const TaskProxy &task_proxy,
VariableOrderType variable_order_type,
utils::RandomNumberGenerator &rng);
VariableOrderFinder(
const TaskProxy &task_proxy,
VariableOrderType variable_order_type,
std::shared_ptr<utils::RandomNumberGenerator> rng = nullptr);
~VariableOrderFinder() = default;
bool done() const;
int next();
Expand Down

0 comments on commit 8eba9ac

Please sign in to comment.