Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

string parser for downard #2

Merged
merged 6 commits into from
Aug 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 18 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,26 +41,31 @@ python3 run_gnn.py <DOMAIN_PDDL> <TASK_PDDL> -m <WEIGHTS_FILE> -r <REPRESENTATIO

### Training
#### Loading the training dataset
Requires access to `plan_objects.zip`. Also requires packages in `requirements.txt` or alternatively use the singularity
container as in [Search Evaluation](#search-evaluation). Perform the following steps
- enter the ```learner``` directory
- create ```data``` directory in the ```learner``` directory
- unzip ```plan_objects.zip``` and put into ```data``` (there should now be a directory
```path_to_goose/learner/data/plan_objects```)
- run the following while in the ```learner``` directory:
Requires access to `plan_objects.zip`. Also requires packages in `requirements.txt` using for example a virtual environment
and `pip install -r requirements.txt`, or alternatively use the singularity container as in [Search](#search). Perform the
following steps
- enter the `learner` directory
- create `data` directory in the `learner` directory
- unzip `plan_objects.zip` and put into `data` (there should now be a directory
`path_to_goose/learner/data/plan_objects`)
- run the following while in the `learner` directory:
```
python3 scripts/generate_graphs.py llg
python3 generate_graphs_gnn.py --regenerate <REPRESENTATION>
```
for <REPRESENTATION> from `llg, dlg, slg, glg, flg` or generate them all at once with
```
sh dataset/generate_all_graphs_gnn.sh
```

#### Domain-dependent training
Requires packages in `requirements.txt` or alternatively use the singularity container as in [Search
Evaluation](#search-evaluation). To train, go into ```learner``` directory (`cd learner`). Then run
Requires packages in `requirements.txt` or alternatively use the singularity container as in [Search](#search). To train, go
into ```learner``` directory (`cd learner`) and run
```
python3 train_gnn.py -m RGNN -r llg -d goose-<DOMAIN>-only --save-file <SAVE_FILE>
```
where you replace ```<DOMAIN>``` by any domain from ```blocks, ferry, gripper, n-puzzle, sokoban, spanner, visitall,
visitsome``` and ```<SAVE_FILE>``` is the name of the save file ending in `.dt` for the trained weights of the models which
would then be located in ```trained_models/<SAVE_FILE>``` after training.
where you replace `<DOMAIN>` by any domain from `blocks, ferry, gripper, n-puzzle, sokoban, spanner, visitall,
visitsome` and `<SAVE_FILE>` is the name of the save file ending in `.dt` for the trained weights of the models which
would then be located in `trained_models/<SAVE_FILE>` after training.

## Kernels
### Search
Expand Down
80 changes: 17 additions & 63 deletions downward/src/search/heuristics/goose_heuristic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,8 @@ using std::string;
namespace goose_heuristic {
GooseHeuristic::GooseHeuristic(const plugins::Options &opts)
: Heuristic(opts) {

initialise_model(opts);
initialise_fact_strings();

}

void GooseHeuristic::initialise_model(const plugins::Options &opts) {
Expand All @@ -42,46 +40,10 @@ void GooseHeuristic::initialise_model(const plugins::Options &opts) {
// python will be printed to stderr, even if it is not an error.
sys.attr("stderr") = sys.attr("stdout");

// A really disgusting hack because FeaturePlugin cannot parse string options
std::string config_path;
switch (opts.get<int>("graph"))
{
case 0: config_path = "slg"; break;
case 1: config_path = "flg"; break;
case 2: config_path = "dlg"; break;
case 3: config_path = "llg"; break;
default:
std::cout << "Unknown enum of graph representation" << std::endl;
exit(-1);
}

// Parse paths from file at config_path
std::string model_path;
std::string domain_file;
std::string instance_file;

std::string line;
std::ifstream config_file(config_path);
int file_line = 0;
while (getline(config_file, line)) {
switch (file_line) {
case 0:
model_path = line;
break;
case 1:
domain_file = line;
break;
case 2:
instance_file = line;
break;
default:
std::cout << "config file " << config_path
<< " must only have 3 lines" << std::endl;
exit(-1);
}
file_line++;
}
config_file.close();
// Read paths
std::string model_path = opts.get<string>("model_path");
std::string domain_file = opts.get<string>("domain_file");
std::string instance_file = opts.get<string>("instance_file");

// Throw everything into Python code
std::cout << "Trying to load model from file " << model_path << " ...\n";
Expand Down Expand Up @@ -187,27 +149,19 @@ class GooseHeuristicFeature : public plugins::TypedFeature<Evaluator, GooseHeuri
document_title("GOOSE heuristic");
document_synopsis("TODO");

add_option<int>(
"graph",
"0: slg, 1: flg, 2: llg, 3: glg",
"-1");

// add_option does not work with <string>

// add_option<string>(
// "model_path",
// "path to trained model weights of file type .dt",
// "default_value.dt");

// add_option<string>(
// "domain_file",
// "Path to the domain file.",
// "default_file.pddl");

// add_option<string>(
// "instance_file",
// "Path to the instance file.",
// "default_file.pddl");
// https://github.com/aibasel/downward/pull/170 for string options
add_option<string>(
"model_path",
"path to trained model weights of file type .dt",
"default_value.dt");
add_option<string>(
"domain_file",
"Path to the domain file.",
"default_file.pddl");
add_option<string>(
"instance_file",
"Path to the instance file.",
"default_file.pddl");

Heuristic::add_options_to_feature(*this);

Expand Down
6 changes: 5 additions & 1 deletion downward/src/search/parser/abstract_syntax_tree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,8 @@ DecoratedASTNodePtr LiteralNode::decorate(DecorateContext &context) const {
switch (value.type) {
case TokenType::BOOLEAN:
return utils::make_unique_ptr<BoolLiteralNode>(value.content);
case TokenType::STRING:
return utils::make_unique_ptr<StringLiteralNode>(value.content);
case TokenType::INTEGER:
return utils::make_unique_ptr<IntLiteralNode>(value.content);
case TokenType::FLOAT:
Expand All @@ -440,6 +442,8 @@ const plugins::Type &LiteralNode::get_type(DecorateContext &context) const {
switch (value.type) {
case TokenType::BOOLEAN:
return plugins::TypeRegistry::instance()->get_type<bool>();
case TokenType::STRING:
return plugins::TypeRegistry::instance()->get_type<string>();
case TokenType::INTEGER:
return plugins::TypeRegistry::instance()->get_type<int>();
case TokenType::FLOAT:
Expand All @@ -454,4 +458,4 @@ const plugins::Type &LiteralNode::get_type(DecorateContext &context) const {
token_type_name(value.type) + "'.");
}
}
}
}
27 changes: 26 additions & 1 deletion downward/src/search/parser/decorated_abstract_syntax_tree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,19 @@ void BoolLiteralNode::dump(string indent) const {
cout << indent << "BOOL: " << value << endl;
}

StringLiteralNode::StringLiteralNode(const string &value)
: value(value) {
}

plugins::Any StringLiteralNode::construct(ConstructContext &context) const {
utils::TraceBlock block(context, "Constructing string value from '" + value + "'");
return value;
}

void StringLiteralNode::dump(string indent) const {
cout << indent << "STRING: " << value << endl;
}

IntLiteralNode::IntLiteralNode(const string &value)
: value(value) {
}
Expand Down Expand Up @@ -473,6 +486,18 @@ shared_ptr<DecoratedASTNode> BoolLiteralNode::clone_shared() const {
return make_shared<BoolLiteralNode>(*this);
}

StringLiteralNode::StringLiteralNode(const StringLiteralNode &other)
: value(other.value) {
}

unique_ptr<DecoratedASTNode> StringLiteralNode::clone() const {
return utils::make_unique_ptr<StringLiteralNode>(*this);
}

shared_ptr<DecoratedASTNode> StringLiteralNode::clone_shared() const {
return make_shared<StringLiteralNode>(*this);
}

IntLiteralNode::IntLiteralNode(const IntLiteralNode &other)
: value(other.value) {
}
Expand Down Expand Up @@ -534,4 +559,4 @@ unique_ptr<DecoratedASTNode> CheckBoundsNode::clone() const {
shared_ptr<DecoratedASTNode> CheckBoundsNode::clone_shared() const {
return make_shared<CheckBoundsNode>(*this);
}
}
}
16 changes: 15 additions & 1 deletion downward/src/search/parser/decorated_abstract_syntax_tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,20 @@ class BoolLiteralNode : public DecoratedASTNode {
BoolLiteralNode(const BoolLiteralNode &other);
};

class StringLiteralNode : public DecoratedASTNode {
std::string value;
public:
StringLiteralNode(const std::string &value);

plugins::Any construct(ConstructContext &context) const override;
void dump(std::string indent) const override;

// TODO: once we get rid of lazy construction, this should no longer be necessary.
virtual std::unique_ptr<DecoratedASTNode> clone() const override;
virtual std::shared_ptr<DecoratedASTNode> clone_shared() const override;
StringLiteralNode(const StringLiteralNode &other);
};

class IntLiteralNode : public DecoratedASTNode {
std::string value;
public:
Expand Down Expand Up @@ -234,4 +248,4 @@ class CheckBoundsNode : public DecoratedASTNode {
CheckBoundsNode(const CheckBoundsNode &other);
};
}
#endif
#endif
12 changes: 10 additions & 2 deletions downward/src/search/parser/lexical_analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ static vector<pair<TokenType, regex>> construct_token_type_expressions() {
{TokenType::INTEGER,
R"([+-]?(infinity|\d+([kmg]\b)?))"},
{TokenType::BOOLEAN, R"(true|false)"},
// TODO: support quoted strings.
{TokenType::STRING, R"("([^"]*)\")"},
{TokenType::LET, R"(let)"},
{TokenType::IDENTIFIER, R"([a-zA-Z_]\w*)"}
};
Expand Down Expand Up @@ -59,7 +61,13 @@ TokenStream split_tokens(const string &text) {
TokenType token_type = type_and_expression.first;
const regex &expression = type_and_expression.second;
if (regex_search(start, end, match, expression)) {
tokens.push_back({utils::tolower(match[1]), token_type});
string value;
if (token_type == TokenType::STRING) {
value = match[2];
} else {
value = utils::tolower(match[1]);
}
tokens.push_back({value, token_type});
start += match[0].length();
has_match = true;
break;
Expand All @@ -86,4 +94,4 @@ TokenStream split_tokens(const string &text) {
}
return TokenStream(move(tokens));
}
}
}
7 changes: 5 additions & 2 deletions downward/src/search/parser/syntax_analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ static unordered_set<TokenType> literal_tokens {
TokenType::FLOAT,
TokenType::INTEGER,
TokenType::BOOLEAN,
TokenType::STRING,
TokenType::IDENTIFIER
};

Expand Down Expand Up @@ -193,7 +194,8 @@ static ASTNodePtr parse_list(TokenStream &tokens, SyntaxAnalyzerContext &context

static vector<TokenType> PARSE_NODE_TOKEN_TYPES = {
TokenType::LET, TokenType::IDENTIFIER, TokenType::BOOLEAN,
TokenType::INTEGER, TokenType::FLOAT, TokenType::OPENING_BRACKET};
TokenType::STRING, TokenType::INTEGER, TokenType::FLOAT,
TokenType::OPENING_BRACKET};

static ASTNodePtr parse_node(TokenStream &tokens,
SyntaxAnalyzerContext &context) {
Expand All @@ -220,6 +222,7 @@ static ASTNodePtr parse_node(TokenStream &tokens,
return parse_literal(tokens, context);
}
case TokenType::BOOLEAN:
case TokenType::STRING:
case TokenType::INTEGER:
case TokenType::FLOAT:
return parse_literal(tokens, context);
Expand All @@ -244,4 +247,4 @@ ASTNodePtr parse(TokenStream &tokens) {
}
return node;
}
}
}
6 changes: 3 additions & 3 deletions downward/src/search/parser/token_stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,12 @@ string token_type_name(TokenType token_type) {
return "Float";
case TokenType::BOOLEAN:
return "Boolean";
case TokenType::STRING:
return "String";
case TokenType::IDENTIFIER:
return "Identifier";
case TokenType::LET:
return "Let";
case TokenType::PATH:
return "Path";
default:
ABORT("Unknown token type.");
}
Expand All @@ -116,4 +116,4 @@ ostream &operator<<(ostream &out, const Token &token) {
out << "<Type: '" << token.type << "', Value: '" << token.content << "'>";
return out;
}
}
}
6 changes: 3 additions & 3 deletions downward/src/search/parser/token_stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ enum class TokenType {
INTEGER,
FLOAT,
BOOLEAN,
STRING,
IDENTIFIER,
LET,
PATH,
LET
};

struct Token {
Expand Down Expand Up @@ -59,4 +59,4 @@ struct hash<parser::TokenType> {
}
};
}
#endif
#endif
3 changes: 2 additions & 1 deletion downward/src/search/plugins/types.cc
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ BasicType TypeRegistry::NO_TYPE = BasicType(typeid(void), "<no type>");

TypeRegistry::TypeRegistry() {
insert_basic_type<bool>();
insert_basic_type<string>();
insert_basic_type<int>();
insert_basic_type<double>();
}
Expand Down Expand Up @@ -345,4 +346,4 @@ const Type &TypeRegistry::get_nonlist_type(type_index type) const {
}
return *registered_types.at(type);
}
}
}
1 change: 1 addition & 0 deletions learner/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ saved_models*
data
lifted
plans
plots

slg
flg
Expand Down
Loading