Skip to content

Commit

Permalink
change constructors in pdb code
Browse files Browse the repository at this point in the history
  • Loading branch information
FlorianPommerening committed Feb 2, 2024
1 parent d5ead08 commit b49e3ae
Show file tree
Hide file tree
Showing 36 changed files with 554 additions and 179 deletions.
36 changes: 26 additions & 10 deletions src/search/pdbs/canonical_pdbs_heuristic.cc
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#include "canonical_pdbs_heuristic.h"

#include "dominance_pruning.h"
#include "pattern_generator.h"
#include "utils.h"

#include "../plugins/plugin.h"
Expand All @@ -10,15 +9,15 @@

#include <iostream>
#include <limits>
#include <memory>

using namespace std;

namespace pdbs {
static CanonicalPDBs get_canonical_pdbs_from_options(
const shared_ptr<AbstractTask> &task, const plugins::Options &opts, utils::LogProxy &log) {
shared_ptr<PatternCollectionGenerator> pattern_generator =
opts.get<shared_ptr<PatternCollectionGenerator>>("patterns");
const shared_ptr<AbstractTask> &task,
const shared_ptr<PatternCollectionGenerator> &pattern_generator,
double max_time_dominance_pruning,
utils::LogProxy &log) {
utils::Timer timer;
if (log.is_at_least_normal()) {
log << "Initializing canonical PDB heuristic..." << endl;
Expand All @@ -36,7 +35,6 @@ static CanonicalPDBs get_canonical_pdbs_from_options(
shared_ptr<vector<PatternClique>> pattern_cliques =
pattern_collection_info.get_pattern_cliques();

double max_time_dominance_pruning = opts.get<double>("max_time_dominance_pruning");
if (max_time_dominance_pruning > 0.0) {
int num_variables = TaskProxy(*task).get_variables().size();
/*
Expand All @@ -62,9 +60,16 @@ static CanonicalPDBs get_canonical_pdbs_from_options(
return CanonicalPDBs(pdbs, pattern_cliques);
}

CanonicalPDBsHeuristic::CanonicalPDBsHeuristic(const plugins::Options &opts)
: Heuristic(opts),
canonical_pdbs(get_canonical_pdbs_from_options(task, opts, log)) {
CanonicalPDBsHeuristic::CanonicalPDBsHeuristic(
const shared_ptr<PatternCollectionGenerator> &pattern_generator,
double max_time_dominance_pruning,
const shared_ptr<AbstractTask> &transform,
bool cache_estimates,
const string &name,
utils::Verbosity verbosity)
: Heuristic(transform, cache_estimates, name, verbosity),
canonical_pdbs(get_canonical_pdbs_from_options(
task, pattern_generator, max_time_dominance_pruning, log)) {
}

int CanonicalPDBsHeuristic::compute_heuristic(const State &ancestor_state) {
Expand Down Expand Up @@ -106,7 +111,7 @@ class CanonicalPDBsHeuristicFeature : public plugins::TypedFeature<Evaluator, Ca
"pattern generation method",
"systematic(1)");
add_canonical_pdbs_options_to_feature(*this);
Heuristic::add_options_to_feature(*this);
Heuristic::add_options_to_feature(*this, "cpdbs");

document_language_support("action costs", "supported");
document_language_support("conditional effects", "not supported");
Expand All @@ -117,6 +122,17 @@ class CanonicalPDBsHeuristicFeature : public plugins::TypedFeature<Evaluator, Ca
document_property("safe", "yes");
document_property("preferred operators", "no");
}

virtual shared_ptr<CanonicalPDBsHeuristic> create_component(
const plugins::Options &opts, const utils::Context &) const override {
return make_shared<CanonicalPDBsHeuristic>(
opts.get<shared_ptr<PatternCollectionGenerator>>("patterns"),
opts.get<double>("max_time_dominance_pruning"),
opts.get<shared_ptr<AbstractTask>>("transform"),
opts.get<bool>("cache_estimates"),
opts.get<string>("name"),
opts.get<utils::Verbosity>("verbosity"));
}
};

static plugins::FeaturePlugin<CanonicalPDBsHeuristicFeature> _plugin;
Expand Down
12 changes: 10 additions & 2 deletions src/search/pdbs/canonical_pdbs_heuristic.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
#define PDBS_CANONICAL_PDBS_HEURISTIC_H

#include "canonical_pdbs.h"
#include "pattern_generator.h"

#include "../heuristic.h"

#include <memory>

namespace plugins {
class Feature;
}
Expand All @@ -18,8 +21,13 @@ class CanonicalPDBsHeuristic : public Heuristic {
virtual int compute_heuristic(const State &ancestor_state) override;

public:
explicit CanonicalPDBsHeuristic(const plugins::Options &opts);
virtual ~CanonicalPDBsHeuristic() = default;
CanonicalPDBsHeuristic(
const std::shared_ptr<PatternCollectionGenerator> &pattern_generator,
double max_time_dominance_pruning,
const std::shared_ptr<AbstractTask> &transform,
bool cache_estimates,
const std::string &name,
utils::Verbosity verbosity);
};

void add_canonical_pdbs_options_to_feature(plugins::Feature &feature);
Expand Down
21 changes: 17 additions & 4 deletions src/search/pdbs/pattern_collection_generator_combo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,12 @@ using namespace std;

namespace pdbs {
PatternCollectionGeneratorCombo::PatternCollectionGeneratorCombo(
const plugins::Options &opts)
: PatternCollectionGenerator(opts), opts(opts) {
int max_states,
const string &name,
utils::Verbosity verbosity)
: PatternCollectionGenerator(name, verbosity), max_states(max_states),
sub_generator_name(name + "_nested_generator"), verbosity(verbosity) {
// TODO 1082 does it make sense to store name and verbosity here? At least name conflicts with the function name().
}

string PatternCollectionGeneratorCombo::name() const {
Expand All @@ -31,7 +35,7 @@ PatternCollectionInformation PatternCollectionGeneratorCombo::compute_patterns(
TaskProxy task_proxy(*task);
shared_ptr<PatternCollection> patterns = make_shared<PatternCollection>();

PatternGeneratorGreedy large_pattern_generator(opts);
PatternGeneratorGreedy large_pattern_generator(max_states, sub_generator_name, verbosity);
Pattern large_pattern = large_pattern_generator.generate(task).get_pattern();
set<int> used_vars(large_pattern.begin(), large_pattern.end());
patterns->push_back(move(large_pattern));
Expand All @@ -55,7 +59,16 @@ class PatternCollectionGeneratorComboFeature : public plugins::TypedFeature<Patt
"maximum abstraction size for combo strategy",
"1000000",
plugins::Bounds("1", "infinity"));
add_generator_options_to_feature(*this);
add_generator_options_to_feature(*this, "combo");
}

virtual shared_ptr<PatternCollectionGeneratorCombo> create_component(
const plugins::Options &opts, const utils::Context &) const override {
return make_shared<PatternCollectionGeneratorCombo>(
opts.get<int>("max_states"),
opts.get<string>("name"),
opts.get<utils::Verbosity>("verbosity")
);
}
};

Expand Down
10 changes: 7 additions & 3 deletions src/search/pdbs/pattern_collection_generator_combo.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,18 @@ namespace pdbs {
/* Take one large pattern and then single-variable patterns for
all goal variables that are not in the large pattern. */
class PatternCollectionGeneratorCombo : public PatternCollectionGenerator {
plugins::Options opts;
int max_states;
std::string sub_generator_name;
utils::Verbosity verbosity;

virtual std::string name() const override;
virtual PatternCollectionInformation compute_patterns(
const std::shared_ptr<AbstractTask> &task) override;
public:
explicit PatternCollectionGeneratorCombo(const plugins::Options &opts);
virtual ~PatternCollectionGeneratorCombo() = default;
PatternCollectionGeneratorCombo(
int max_states,
const std::string &name,
utils::Verbosity verbosity);
};
}

Expand Down
35 changes: 27 additions & 8 deletions src/search/pdbs/pattern_collection_generator_disjoint_cegar.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,19 @@ using namespace std;

namespace pdbs {
PatternCollectionGeneratorDisjointCegar::PatternCollectionGeneratorDisjointCegar(
const plugins::Options &opts)
: PatternCollectionGenerator(opts),
max_pdb_size(opts.get<int>("max_pdb_size")),
max_collection_size(opts.get<int>("max_collection_size")),
max_time(opts.get<double>("max_time")),
use_wildcard_plans(opts.get<bool>("use_wildcard_plans")),
rng(utils::parse_rng_from_options(opts)) {
int max_pdb_size,
int max_collection_size,
double max_time,
bool use_wildcard_plans,
int random_seed,
const string &name,
utils::Verbosity verbosity)
: PatternCollectionGenerator(name, verbosity),
max_pdb_size(max_pdb_size),
max_collection_size(max_collection_size),
max_time(max_time),
use_wildcard_plans(use_wildcard_plans),
rng(utils::get_rng(random_seed)) {
}

string PatternCollectionGeneratorDisjointCegar::name() const {
Expand Down Expand Up @@ -75,11 +81,24 @@ class PatternCollectionGeneratorDisjointCegarFeature : public plugins::TypedFeat
"infinity",
plugins::Bounds("0.0", "infinity"));
add_cegar_wildcard_option_to_feature(*this);
add_generator_options_to_feature(*this);
utils::add_rng_options(*this);
add_generator_options_to_feature(*this, "disjoint_cegar");

add_cegar_implementation_notes_to_feature(*this);
}

virtual shared_ptr<PatternCollectionGeneratorDisjointCegar> create_component(
const plugins::Options &opts, const utils::Context &) const override {
return make_shared<PatternCollectionGeneratorDisjointCegar>(
opts.get<int>("max_pdb_size"),
opts.get<int>("max_collection_size"),
opts.get<double>("max_time"),
opts.get<bool>("use_wildcard_plans"),
opts.get<int>("random_seed"),
opts.get<string>("name"),
opts.get<utils::Verbosity>("verbosity")
);
}
};

static plugins::FeaturePlugin<PatternCollectionGeneratorDisjointCegarFeature> _plugin;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,14 @@ class PatternCollectionGeneratorDisjointCegar : public PatternCollectionGenerato
virtual PatternCollectionInformation compute_patterns(
const std::shared_ptr<AbstractTask> &task) override;
public:
explicit PatternCollectionGeneratorDisjointCegar(const plugins::Options &opts);
PatternCollectionGeneratorDisjointCegar(
int max_pdb_size,
int max_collection_size,
double max_time,
bool use_wildcard_plans,
int random_seed,
const std::string &name,
utils::Verbosity verbosity);
};
}

Expand Down
39 changes: 30 additions & 9 deletions src/search/pdbs/pattern_collection_generator_genetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,21 @@ using namespace std;

namespace pdbs {
PatternCollectionGeneratorGenetic::PatternCollectionGeneratorGenetic(
const plugins::Options &opts)
: PatternCollectionGenerator(opts),
pdb_max_size(opts.get<int>("pdb_max_size")),
num_collections(opts.get<int>("num_collections")),
num_episodes(opts.get<int>("num_episodes")),
mutation_probability(opts.get<double>("mutation_probability")),
disjoint_patterns(opts.get<bool>("disjoint")),
rng(utils::parse_rng_from_options(opts)) {
int pdb_max_size,
int num_collections,
int num_episodes,
double mutation_probability,
bool disjoint,
int random_seed,
const string &name,
utils::Verbosity verbosity)
: PatternCollectionGenerator(name, verbosity),
pdb_max_size(pdb_max_size),
num_collections(num_collections),
num_episodes(num_episodes),
mutation_probability(mutation_probability),
disjoint_patterns(disjoint),
rng(utils::get_rng(random_seed)) {
}

void PatternCollectionGeneratorGenetic::select(
Expand Down Expand Up @@ -345,7 +352,7 @@ class PatternCollectionGeneratorGeneticFeature : public plugins::TypedFeature<Pa
"fitness) if its patterns are not disjoint",
"false");
utils::add_rng_options(*this);
add_generator_options_to_feature(*this);
add_generator_options_to_feature(*this, "genetic");

document_note(
"Note",
Expand Down Expand Up @@ -387,6 +394,20 @@ class PatternCollectionGeneratorGeneticFeature : public plugins::TypedFeature<Pa
document_language_support("conditional effects", "not supported");
document_language_support("axioms", "not supported");
}

virtual shared_ptr<PatternCollectionGeneratorGenetic> create_component(
const plugins::Options &opts, const utils::Context &) const override {
return make_shared<PatternCollectionGeneratorGenetic>(
opts.get<int>("pdb_max_size"),
opts.get<int>("num_collections"),
opts.get<int>("num_episodes"),
opts.get<double>("mutation_probability"),
opts.get<bool>("disjoint"),
opts.get<int>("random_seed"),
opts.get<string>("name"),
opts.get<utils::Verbosity>("verbosity")
);
}
};

static plugins::FeaturePlugin<PatternCollectionGeneratorGeneticFeature> _plugin;
Expand Down
10 changes: 9 additions & 1 deletion src/search/pdbs/pattern_collection_generator_genetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,15 @@ class PatternCollectionGeneratorGenetic : public PatternCollectionGenerator {
virtual PatternCollectionInformation compute_patterns(
const std::shared_ptr<AbstractTask> &task) override;
public:
explicit PatternCollectionGeneratorGenetic(const plugins::Options &opts);
PatternCollectionGeneratorGenetic(
int pdb_max_size,
int num_collections,
int num_episodes,
double mutation_probability,
bool disjoint,
int random_seed,
const std::string &name,
utils::Verbosity verbosity);
};
}

Expand Down
Loading

0 comments on commit b49e3ae

Please sign in to comment.