Commit

Merge pull request #2 from DillonZChen/kernels
string parser for downward
DillonZChen authored Aug 20, 2023
2 parents 20b0d4a + b81c54e commit dc48878
Showing 61 changed files with 1,093 additions and 568 deletions.
31 changes: 18 additions & 13 deletions README.md
@@ -41,26 +41,31 @@ python3 run_gnn.py <DOMAIN_PDDL> <TASK_PDDL> -m <WEIGHTS_FILE> -r <REPRESENTATIO

### Training
#### Loading the training dataset
Requires access to `plan_objects.zip`. Also requires packages in `requirements.txt` or alternatively use the singularity
container as in [Search Evaluation](#search-evaluation). Perform the following steps
- enter the ```learner``` directory
- create ```data``` directory in the ```learner``` directory
- unzip ```plan_objects.zip``` and put into ```data``` (there should now be a directory
```path_to_goose/learner/data/plan_objects```)
- run the following while in the ```learner``` directory:
Requires access to `plan_objects.zip`. Also requires the packages in `requirements.txt`, installed for example into a
virtual environment with `pip install -r requirements.txt`, or alternatively use the singularity container as in
[Search](#search). Perform the following steps
- enter the `learner` directory
- create `data` directory in the `learner` directory
- unzip `plan_objects.zip` and put into `data` (there should now be a directory
`path_to_goose/learner/data/plan_objects`)
- run the following while in the `learner` directory:
```
python3 scripts/generate_graphs.py llg
python3 generate_graphs_gnn.py --regenerate <REPRESENTATION>
```
for `<REPRESENTATION>` from `llg, dlg, slg, glg, flg`, or generate them all at once with
```
sh dataset/generate_all_graphs_gnn.sh
```
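For example, `python3 generate_graphs_gnn.py --regenerate llg` regenerates only the `llg` graphs.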

#### Domain-dependent training
Requires packages in `requirements.txt` or alternatively use the singularity container as in [Search
Evaluation](#search-evaluation). To train, go into ```learner``` directory (`cd learner`). Then run
Requires packages in `requirements.txt` or alternatively use the singularity container as in [Search](#search). To train, go
into the `learner` directory (`cd learner`) and run
```
python3 train_gnn.py -m RGNN -r llg -d goose-<DOMAIN>-only --save-file <SAVE_FILE>
```
where you replace ```<DOMAIN>``` by any domain from ```blocks, ferry, gripper, n-puzzle, sokoban, spanner, visitall,
visitsome``` and ```<SAVE_FILE>``` is the name of the save file ending in `.dt` for the trained weights of the models which
would then be located in ```trained_models/<SAVE_FILE>``` after training.
where you replace `<DOMAIN>` by any domain from `blocks, ferry, gripper, n-puzzle, sokoban, spanner, visitall,
visitsome`, and `<SAVE_FILE>` is the name of the save file (ending in `.dt`) for the trained model weights, which
will be located in `trained_models/<SAVE_FILE>` after training.
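For example, with an illustrative save-file name, `python3 train_gnn.py -m RGNN -r llg -d goose-blocks-only --save-file blocks_rgnn_llg.dt` trains on the blocks domain and writes the weights to `trained_models/blocks_rgnn_llg.dt`.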

## Kernels
### Search
80 changes: 17 additions & 63 deletions downward/src/search/heuristics/goose_heuristic.cc
@@ -13,10 +13,8 @@ using std::string;
namespace goose_heuristic {
GooseHeuristic::GooseHeuristic(const plugins::Options &opts)
: Heuristic(opts) {

initialise_model(opts);
initialise_fact_strings();

}

void GooseHeuristic::initialise_model(const plugins::Options &opts) {
@@ -42,46 +40,10 @@ void GooseHeuristic::initialise_model(const plugins::Options &opts) {
// python will be printed to stderr, even if it is not an error.
sys.attr("stderr") = sys.attr("stdout");

// A really disgusting hack because FeaturePlugin cannot parse string options
std::string config_path;
switch (opts.get<int>("graph"))
{
case 0: config_path = "slg"; break;
case 1: config_path = "flg"; break;
case 2: config_path = "dlg"; break;
case 3: config_path = "llg"; break;
default:
std::cout << "Unknown enum of graph representation" << std::endl;
exit(-1);
}

// Parse paths from file at config_path
std::string model_path;
std::string domain_file;
std::string instance_file;

std::string line;
std::ifstream config_file(config_path);
int file_line = 0;
while (getline(config_file, line)) {
switch (file_line) {
case 0:
model_path = line;
break;
case 1:
domain_file = line;
break;
case 2:
instance_file = line;
break;
default:
std::cout << "config file " << config_path
<< " must only have 3 lines" << std::endl;
exit(-1);
}
file_line++;
}
config_file.close();
// Read paths
std::string model_path = opts.get<string>("model_path");
std::string domain_file = opts.get<string>("domain_file");
std::string instance_file = opts.get<string>("instance_file");

// Throw everything into Python code
std::cout << "Trying to load model from file " << model_path << " ...\n";
@@ -187,27 +149,19 @@ class GooseHeuristicFeature : public plugins::TypedFeature<Evaluator, GooseHeuri
document_title("GOOSE heuristic");
document_synopsis("TODO");

add_option<int>(
"graph",
"0: slg, 1: flg, 2: llg, 3: glg",
"-1");

// add_option does not work with <string>

// add_option<string>(
// "model_path",
// "path to trained model weights of file type .dt",
// "default_value.dt");

// add_option<string>(
// "domain_file",
// "Path to the domain file.",
// "default_file.pddl");

// add_option<string>(
// "instance_file",
// "Path to the instance file.",
// "default_file.pddl");
// https://github.com/aibasel/downward/pull/170 for string options
add_option<string>(
"model_path",
"path to trained model weights of file type .dt",
"default_value.dt");
add_option<string>(
"domain_file",
"Path to the domain file.",
"default_file.pddl");
add_option<string>(
"instance_file",
"Path to the instance file.",
"default_file.pddl");

Heuristic::add_options_to_feature(*this);

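With the planner now able to parse string options, the model and PDDL paths can be passed directly in the heuristic configuration instead of being selected through the integer `graph` enum and a three-line config file. A hypothetical invocation (the `goose` plugin key and the file names are illustrative assumptions, not taken from this commit) could look like `--search 'eager_greedy([goose(model_path="model.dt", domain_file="domain.pddl", instance_file="problem.pddl")])'`.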
6 changes: 5 additions & 1 deletion downward/src/search/parser/abstract_syntax_tree.cc
@@ -419,6 +419,8 @@ DecoratedASTNodePtr LiteralNode::decorate(DecorateContext &context) const {
switch (value.type) {
case TokenType::BOOLEAN:
return utils::make_unique_ptr<BoolLiteralNode>(value.content);
case TokenType::STRING:
return utils::make_unique_ptr<StringLiteralNode>(value.content);
case TokenType::INTEGER:
return utils::make_unique_ptr<IntLiteralNode>(value.content);
case TokenType::FLOAT:
@@ -440,6 +442,8 @@ const plugins::Type &LiteralNode::get_type(DecorateContext &context) const {
switch (value.type) {
case TokenType::BOOLEAN:
return plugins::TypeRegistry::instance()->get_type<bool>();
case TokenType::STRING:
return plugins::TypeRegistry::instance()->get_type<string>();
case TokenType::INTEGER:
return plugins::TypeRegistry::instance()->get_type<int>();
case TokenType::FLOAT:
@@ -454,4 +458,4 @@ const plugins::Type &LiteralNode::get_type(DecorateContext &context) const {
token_type_name(value.type) + "'.");
}
}
}
}
27 changes: 26 additions & 1 deletion downward/src/search/parser/decorated_abstract_syntax_tree.cc
@@ -218,6 +218,19 @@ void BoolLiteralNode::dump(string indent) const {
cout << indent << "BOOL: " << value << endl;
}

StringLiteralNode::StringLiteralNode(const string &value)
: value(value) {
}

plugins::Any StringLiteralNode::construct(ConstructContext &context) const {
utils::TraceBlock block(context, "Constructing string value from '" + value + "'");
return value;
}

void StringLiteralNode::dump(string indent) const {
cout << indent << "STRING: " << value << endl;
}

IntLiteralNode::IntLiteralNode(const string &value)
: value(value) {
}
@@ -473,6 +486,18 @@ shared_ptr<DecoratedASTNode> BoolLiteralNode::clone_shared() const {
return make_shared<BoolLiteralNode>(*this);
}

StringLiteralNode::StringLiteralNode(const StringLiteralNode &other)
: value(other.value) {
}

unique_ptr<DecoratedASTNode> StringLiteralNode::clone() const {
return utils::make_unique_ptr<StringLiteralNode>(*this);
}

shared_ptr<DecoratedASTNode> StringLiteralNode::clone_shared() const {
return make_shared<StringLiteralNode>(*this);
}

IntLiteralNode::IntLiteralNode(const IntLiteralNode &other)
: value(other.value) {
}
@@ -534,4 +559,4 @@ unique_ptr<DecoratedASTNode> CheckBoundsNode::clone() const {
shared_ptr<DecoratedASTNode> CheckBoundsNode::clone_shared() const {
return make_shared<CheckBoundsNode>(*this);
}
}
}
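The new node follows the same pattern as the existing literal nodes: it stores the raw token text and only wraps it in a `plugins::Any` when `construct` is called. A minimal standalone sketch of that pattern, using `std::any` as a stand-in for `plugins::Any` (an assumption for illustration, not the planner's actual types):

```cpp
#include <any>
#include <iostream>
#include <string>

// Simplified stand-in for the decorated AST literal node: the token text is
// stored as-is and only turned into a value when construct() is called.
class StringLiteralNode {
    std::string value;
public:
    explicit StringLiteralNode(const std::string &value) : value(value) {}

    // Corresponds to construct(ConstructContext &): a string literal needs no
    // conversion, so the stored text is returned directly.
    std::any construct() const {
        return value;
    }

    void dump(const std::string &indent = "") const {
        std::cout << indent << "STRING: " << value << std::endl;
    }
};

int main() {
    StringLiteralNode node("model.dt");
    node.dump("  ");
    std::any constructed = node.construct();
    std::cout << std::any_cast<std::string>(constructed) << std::endl;
    return 0;
}
```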
16 changes: 15 additions & 1 deletion downward/src/search/parser/decorated_abstract_syntax_tree.h
@@ -157,6 +157,20 @@ class BoolLiteralNode : public DecoratedASTNode {
BoolLiteralNode(const BoolLiteralNode &other);
};

class StringLiteralNode : public DecoratedASTNode {
std::string value;
public:
StringLiteralNode(const std::string &value);

plugins::Any construct(ConstructContext &context) const override;
void dump(std::string indent) const override;

// TODO: once we get rid of lazy construction, this should no longer be necessary.
virtual std::unique_ptr<DecoratedASTNode> clone() const override;
virtual std::shared_ptr<DecoratedASTNode> clone_shared() const override;
StringLiteralNode(const StringLiteralNode &other);
};

class IntLiteralNode : public DecoratedASTNode {
std::string value;
public:
@@ -234,4 +248,4 @@ class CheckBoundsNode : public DecoratedASTNode {
CheckBoundsNode(const CheckBoundsNode &other);
};
}
#endif
#endif
12 changes: 10 additions & 2 deletions downward/src/search/parser/lexical_analyzer.cc
@@ -29,6 +29,8 @@ static vector<pair<TokenType, regex>> construct_token_type_expressions() {
{TokenType::INTEGER,
R"([+-]?(infinity|\d+([kmg]\b)?))"},
{TokenType::BOOLEAN, R"(true|false)"},
// TODO: support quoted strings.
{TokenType::STRING, R"("([^"]*)\")"},
{TokenType::LET, R"(let)"},
{TokenType::IDENTIFIER, R"([a-zA-Z_]\w*)"}
};
@@ -59,7 +61,13 @@ TokenStream split_tokens(const string &text) {
TokenType token_type = type_and_expression.first;
const regex &expression = type_and_expression.second;
if (regex_search(start, end, match, expression)) {
tokens.push_back({utils::tolower(match[1]), token_type});
string value;
if (token_type == TokenType::STRING) {
value = match[2];
} else {
value = utils::tolower(match[1]);
}
tokens.push_back({value, token_type});
start += match[0].length();
has_match = true;
break;
@@ -86,4 +94,4 @@
}
return TokenStream(move(tokens));
}
}
}
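The string token is the only token whose value is not lower-cased: the lexer takes the second submatch, i.e. the text between the quotes, rather than the whole matched token. A small standalone sketch of that behaviour (the outer capture group mirrors how the tokenizer appears to wrap each expression, inferred from the `match[1]`/`match[2]` usage above; this is not the planner's actual code):

```cpp
#include <iostream>
#include <regex>
#include <string>

int main() {
    // The token expression from the lexer, R"("([^"]*)\")", wrapped in an
    // outer group so that match[1] is the whole token (with quotes) and
    // match[2] is the quoted text itself.
    const std::regex string_token(R"(\s*("([^"]*)\"))");
    const std::string text = R"( "path/to/model.dt", rest)";

    std::smatch match;
    if (std::regex_search(text, match, string_token)) {
        std::cout << "token: " << match[1] << '\n';  // "path/to/model.dt" (with quotes)
        std::cout << "value: " << match[2] << '\n';  // path/to/model.dt
    }
    return 0;
}
```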
7 changes: 5 additions & 2 deletions downward/src/search/parser/syntax_analyzer.cc
@@ -162,6 +162,7 @@ static unordered_set<TokenType> literal_tokens {
TokenType::FLOAT,
TokenType::INTEGER,
TokenType::BOOLEAN,
TokenType::STRING,
TokenType::IDENTIFIER
};

@@ -193,7 +194,8 @@ static ASTNodePtr parse_list(TokenStream &tokens, SyntaxAnalyzerContext &context

static vector<TokenType> PARSE_NODE_TOKEN_TYPES = {
TokenType::LET, TokenType::IDENTIFIER, TokenType::BOOLEAN,
TokenType::INTEGER, TokenType::FLOAT, TokenType::OPENING_BRACKET};
TokenType::STRING, TokenType::INTEGER, TokenType::FLOAT,
TokenType::OPENING_BRACKET};

static ASTNodePtr parse_node(TokenStream &tokens,
SyntaxAnalyzerContext &context) {
@@ -220,6 +222,7 @@ static ASTNodePtr parse_node(TokenStream &tokens,
return parse_literal(tokens, context);
}
case TokenType::BOOLEAN:
case TokenType::STRING:
case TokenType::INTEGER:
case TokenType::FLOAT:
return parse_literal(tokens, context);
@@ -244,4 +247,4 @@ ASTNodePtr parse(TokenStream &tokens) {
}
return node;
}
}
}
6 changes: 3 additions & 3 deletions downward/src/search/parser/token_stream.cc
@@ -96,12 +96,12 @@ string token_type_name(TokenType token_type) {
return "Float";
case TokenType::BOOLEAN:
return "Boolean";
case TokenType::STRING:
return "String";
case TokenType::IDENTIFIER:
return "Identifier";
case TokenType::LET:
return "Let";
case TokenType::PATH:
return "Path";
default:
ABORT("Unknown token type.");
}
@@ -116,4 +116,4 @@
out << "<Type: '" << token.type << "', Value: '" << token.content << "'>";
return out;
}
}
}
6 changes: 3 additions & 3 deletions downward/src/search/parser/token_stream.h
@@ -19,9 +19,9 @@ enum class TokenType {
INTEGER,
FLOAT,
BOOLEAN,
STRING,
IDENTIFIER,
LET,
PATH,
LET
};

struct Token {
@@ -59,4 +59,4 @@ struct hash<parser::TokenType> {
}
};
}
#endif
#endif
3 changes: 2 additions & 1 deletion downward/src/search/plugins/types.cc
@@ -292,6 +292,7 @@ BasicType TypeRegistry::NO_TYPE = BasicType(typeid(void), "<no type>");

TypeRegistry::TypeRegistry() {
insert_basic_type<bool>();
insert_basic_type<string>();
insert_basic_type<int>();
insert_basic_type<double>();
}
@@ -345,4 +346,4 @@ const Type &TypeRegistry::get_nonlist_type(type_index type) const {
}
return *registered_types.at(type);
}
}
}
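Registering `string` as a basic type here is what lets plugin code declare `add_option<string>(...)` and read the value back with `opts.get<string>(...)`, as the GOOSE heuristic above now does.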
1 change: 1 addition & 0 deletions learner/.gitignore
@@ -13,6 +13,7 @@ saved_models*
data
lifted
plans
plots

slg
flg
(Diffs for the remaining changed files were not loaded.)
