diff --git a/README.md b/README.md
index a8dac4d0..53da4a35 100644
--- a/README.md
+++ b/README.md
@@ -41,26 +41,31 @@
 python3 run_gnn.py -m <model> -r <rep>
+```
+for <rep> from `llg, dlg, slg, glg, flg` or generate them all at once with
+```
+sh dataset/generate_all_graphs_gnn.sh
 ```
 
 #### Domain-dependent training
-Requires packages in `requirements.txt` or alternatively use the singularity container as in [Search
-Evaluation](#search-evaluation). To train, go into ```learner``` directory (`cd learner`). Then run
+Requires packages in `requirements.txt` or alternatively use the singularity container as in [Search](#search). To train, go
+into the `learner` directory (`cd learner`) and run
 ```
 python3 train_gnn.py -m RGNN -r llg -d goose-<domain>-only --save-file <save file>
 ```
-where you replace ```<domain>``` by any domain from ```blocks, ferry, gripper, n-puzzle, sokoban, spanner, visitall,
-visitsome``` and ```<save file>``` is the name of the save file ending in `.dt` for the trained weights of the models which
-would then be located in ```trained_models/``` after training.
+where you replace `<domain>` by any domain from `blocks, ferry, gripper, n-puzzle, sokoban, spanner, visitall,
+visitsome`, and `<save file>` is the name of the save file, ending in `.dt`, for the trained model weights, which
+are then located in `trained_models/` after training.
 
 ## Kernels
 ### Search
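For example, a concrete instantiation of the training command above, with illustrative values for `<domain>` and `<save file>` (not taken from this diff):

```
python3 train_gnn.py -m RGNN -r llg -d goose-blocks-only --save-file dd_llg_blocks.dt
```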
sys.attr("stderr") = sys.attr("stdout"); - // A really disgusting hack because FeaturePlugin cannot parse string options - std::string config_path; - switch (opts.get("graph")) - { - case 0: config_path = "slg"; break; - case 1: config_path = "flg"; break; - case 2: config_path = "dlg"; break; - case 3: config_path = "llg"; break; - default: - std::cout << "Unknown enum of graph representation" << std::endl; - exit(-1); - } - - // Parse paths from file at config_path - std::string model_path; - std::string domain_file; - std::string instance_file; - - std::string line; - std::ifstream config_file(config_path); - int file_line = 0; - while (getline(config_file, line)) { - switch (file_line) { - case 0: - model_path = line; - break; - case 1: - domain_file = line; - break; - case 2: - instance_file = line; - break; - default: - std::cout << "config file " << config_path - << " must only have 3 lines" << std::endl; - exit(-1); - } - file_line++; - } - config_file.close(); + // Read paths + std::string model_path = opts.get("model_path"); + std::string domain_file = opts.get("domain_file"); + std::string instance_file = opts.get("instance_file"); // Throw everything into Python code std::cout << "Trying to load model from file " << model_path << " ...\n"; @@ -187,27 +149,19 @@ class GooseHeuristicFeature : public plugins::TypedFeature( - "graph", - "0: slg, 1: flg, 2: llg, 3: glg", - "-1"); - - // add_option does not work with - - // add_option( - // "model_path", - // "path to trained model weights of file type .dt", - // "default_value.dt"); - - // add_option( - // "domain_file", - // "Path to the domain file.", - // "default_file.pddl"); - - // add_option( - // "instance_file", - // "Path to the instance file.", - // "default_file.pddl"); + // https://github.com/aibasel/downward/pull/170 for string options + add_option( + "model_path", + "path to trained model weights of file type .dt", + "default_value.dt"); + add_option( + "domain_file", + "Path to the domain file.", + "default_file.pddl"); + add_option( + "instance_file", + "Path to the instance file.", + "default_file.pddl"); Heuristic::add_options_to_feature(*this); diff --git a/downward/src/search/parser/abstract_syntax_tree.cc b/downward/src/search/parser/abstract_syntax_tree.cc index 5aecdb72..27e5b670 100644 --- a/downward/src/search/parser/abstract_syntax_tree.cc +++ b/downward/src/search/parser/abstract_syntax_tree.cc @@ -419,6 +419,8 @@ DecoratedASTNodePtr LiteralNode::decorate(DecorateContext &context) const { switch (value.type) { case TokenType::BOOLEAN: return utils::make_unique_ptr(value.content); + case TokenType::STRING: + return utils::make_unique_ptr(value.content); case TokenType::INTEGER: return utils::make_unique_ptr(value.content); case TokenType::FLOAT: @@ -440,6 +442,8 @@ const plugins::Type &LiteralNode::get_type(DecorateContext &context) const { switch (value.type) { case TokenType::BOOLEAN: return plugins::TypeRegistry::instance()->get_type(); + case TokenType::STRING: + return plugins::TypeRegistry::instance()->get_type(); case TokenType::INTEGER: return plugins::TypeRegistry::instance()->get_type(); case TokenType::FLOAT: @@ -454,4 +458,4 @@ const plugins::Type &LiteralNode::get_type(DecorateContext &context) const { token_type_name(value.type) + "'."); } } -} +} \ No newline at end of file diff --git a/downward/src/search/parser/decorated_abstract_syntax_tree.cc b/downward/src/search/parser/decorated_abstract_syntax_tree.cc index 3a401d9e..068ee593 100644 --- 
diff --git a/downward/src/search/parser/decorated_abstract_syntax_tree.cc b/downward/src/search/parser/decorated_abstract_syntax_tree.cc
index 3a401d9e..068ee593 100644
--- a/downward/src/search/parser/decorated_abstract_syntax_tree.cc
+++ b/downward/src/search/parser/decorated_abstract_syntax_tree.cc
@@ -218,6 +218,19 @@ void BoolLiteralNode::dump(string indent) const {
     cout << indent << "BOOL: " << value << endl;
 }
 
+StringLiteralNode::StringLiteralNode(const string &value)
+    : value(value) {
+}
+
+plugins::Any StringLiteralNode::construct(ConstructContext &context) const {
+    utils::TraceBlock block(context, "Constructing string value from '" + value + "'");
+    return value;
+}
+
+void StringLiteralNode::dump(string indent) const {
+    cout << indent << "STRING: " << value << endl;
+}
+
 IntLiteralNode::IntLiteralNode(const string &value)
     : value(value) {
 }
@@ -473,6 +486,18 @@ shared_ptr<DecoratedASTNode> BoolLiteralNode::clone_shared() const {
     return make_shared<BoolLiteralNode>(*this);
 }
 
+StringLiteralNode::StringLiteralNode(const StringLiteralNode &other)
+    : value(other.value) {
+}
+
+unique_ptr<DecoratedASTNode> StringLiteralNode::clone() const {
+    return utils::make_unique_ptr<StringLiteralNode>(*this);
+}
+
+shared_ptr<DecoratedASTNode> StringLiteralNode::clone_shared() const {
+    return make_shared<StringLiteralNode>(*this);
+}
+
 IntLiteralNode::IntLiteralNode(const IntLiteralNode &other)
     : value(other.value) {
 }
@@ -534,4 +559,4 @@ unique_ptr<DecoratedASTNode> CheckBoundsNode::clone() const {
 shared_ptr<DecoratedASTNode> CheckBoundsNode::clone_shared() const {
     return make_shared<CheckBoundsNode>(*this);
 }
-}
+}
\ No newline at end of file
diff --git a/downward/src/search/parser/decorated_abstract_syntax_tree.h b/downward/src/search/parser/decorated_abstract_syntax_tree.h
index 0094f887..6561560e 100644
--- a/downward/src/search/parser/decorated_abstract_syntax_tree.h
+++ b/downward/src/search/parser/decorated_abstract_syntax_tree.h
@@ -157,6 +157,20 @@ class BoolLiteralNode : public DecoratedASTNode {
     BoolLiteralNode(const BoolLiteralNode &other);
 };
 
+class StringLiteralNode : public DecoratedASTNode {
+    std::string value;
+public:
+    StringLiteralNode(const std::string &value);
+
+    plugins::Any construct(ConstructContext &context) const override;
+    void dump(std::string indent) const override;
+
+    // TODO: once we get rid of lazy construction, this should no longer be necessary.
+    virtual std::unique_ptr<DecoratedASTNode> clone() const override;
+    virtual std::shared_ptr<DecoratedASTNode> clone_shared() const override;
+    StringLiteralNode(const StringLiteralNode &other);
+};
+
 class IntLiteralNode : public DecoratedASTNode {
     std::string value;
 public:
@@ -234,4 +248,4 @@ class CheckBoundsNode : public DecoratedASTNode {
     CheckBoundsNode(const CheckBoundsNode &other);
 };
 }
-#endif
+#endif
\ No newline at end of file
diff --git a/downward/src/search/parser/lexical_analyzer.cc b/downward/src/search/parser/lexical_analyzer.cc
index a127aed9..f31f230d 100644
--- a/downward/src/search/parser/lexical_analyzer.cc
+++ b/downward/src/search/parser/lexical_analyzer.cc
@@ -29,6 +29,8 @@ static vector<pair<TokenType, regex>> construct_token_type_expressions() {
         {TokenType::INTEGER, R"([+-]?(infinity|\d+([kmg]\b)?))"},
         {TokenType::BOOLEAN, R"(true|false)"},
+        // TODO: support quoted strings.
+        {TokenType::STRING, R"("([^"]*)\")"},
         {TokenType::LET, R"(let)"},
         {TokenType::IDENTIFIER, R"([a-zA-Z_]\w*)"}
     };
@@ -59,7 +61,13 @@ TokenStream split_tokens(const string &text) {
         TokenType token_type = type_and_expression.first;
         const regex &expression = type_and_expression.second;
         if (regex_search(start, end, match, expression)) {
-            tokens.push_back({utils::tolower(match[1]), token_type});
+            string value;
+            if (token_type == TokenType::STRING) {
+                value = match[2];
+            } else {
+                value = utils::tolower(match[1]);
+            }
+            tokens.push_back({value, token_type});
             start += match[0].length();
             has_match = true;
             break;
@@ -86,4 +94,4 @@ TokenStream split_tokens(const string &text) {
     }
     return TokenStream(move(tokens));
 }
-}
+}
\ No newline at end of file
diff --git a/downward/src/search/parser/syntax_analyzer.cc b/downward/src/search/parser/syntax_analyzer.cc
index ffcafbfa..62f4fbc3 100644
--- a/downward/src/search/parser/syntax_analyzer.cc
+++ b/downward/src/search/parser/syntax_analyzer.cc
@@ -162,6 +162,7 @@ static unordered_set<TokenType> literal_tokens {
     TokenType::FLOAT,
     TokenType::INTEGER,
     TokenType::BOOLEAN,
+    TokenType::STRING,
     TokenType::IDENTIFIER
 };
 
@@ -193,7 +194,8 @@ static ASTNodePtr parse_list(TokenStream &tokens, SyntaxAnalyzerContext &context
 static vector<TokenType> PARSE_NODE_TOKEN_TYPES = {
     TokenType::LET, TokenType::IDENTIFIER, TokenType::BOOLEAN,
-    TokenType::INTEGER, TokenType::FLOAT, TokenType::OPENING_BRACKET};
+    TokenType::STRING, TokenType::INTEGER, TokenType::FLOAT,
+    TokenType::OPENING_BRACKET};
 
 static ASTNodePtr parse_node(TokenStream &tokens,
                              SyntaxAnalyzerContext &context) {
@@ -220,6 +222,7 @@ static ASTNodePtr parse_node(TokenStream &tokens,
         return parse_literal(tokens, context);
     }
     case TokenType::BOOLEAN:
+    case TokenType::STRING:
     case TokenType::INTEGER:
     case TokenType::FLOAT:
         return parse_literal(tokens, context);
@@ -244,4 +247,4 @@ ASTNodePtr parse(TokenStream &tokens) {
     }
     return node;
 }
-}
+}
\ No newline at end of file
diff --git a/downward/src/search/parser/token_stream.cc b/downward/src/search/parser/token_stream.cc
index 7879be17..24695feb 100644
--- a/downward/src/search/parser/token_stream.cc
+++ b/downward/src/search/parser/token_stream.cc
@@ -96,12 +96,12 @@ string token_type_name(TokenType token_type) {
         return "Float";
     case TokenType::BOOLEAN:
         return "Boolean";
+    case TokenType::STRING:
+        return "String";
     case TokenType::IDENTIFIER:
         return "Identifier";
     case TokenType::LET:
         return "Let";
-    case TokenType::PATH:
-        return "Path";
     default:
         ABORT("Unknown token type.");
     }
@@ -116,4 +116,4 @@ ostream &operator<<(ostream &out, const Token &token) {
     out << "<Type: '" << token_type_name(token.type) << "', Value: '" << token.content << "'>";
     return out;
 }
-}
+}
\ No newline at end of file
diff --git a/downward/src/search/parser/token_stream.h b/downward/src/search/parser/token_stream.h
index 74420c26..01daaddf 100644
--- a/downward/src/search/parser/token_stream.h
+++ b/downward/src/search/parser/token_stream.h
@@ -19,9 +19,9 @@ enum class TokenType {
     INTEGER,
     FLOAT,
     BOOLEAN,
+    STRING,
     IDENTIFIER,
-    LET,
-    PATH,
+    LET
 };
 
 struct Token {
@@ -59,4 +59,4 @@ struct hash {
     }
 };
 }
-#endif
+#endif
\ No newline at end of file
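The STRING rule is the one token whose value must not be lowercased, since file paths are case-sensitive; `split_tokens` therefore keeps `match[2]` (the capture inside the quotes) verbatim. A minimal Python mirror of that behaviour, assuming the lexer wraps each expression in an outer group as the surrounding code suggests:

```python
import re

# Python analogue of the new STRING token rule: group 1 is the whole
# quoted token, group 2 the unquoted content that becomes the value.
STRING_EXPR = re.compile(r'("([^"]*)")')

m = STRING_EXPR.match('"trained_models/Model.dt"')
assert m.group(2) == "trained_models/Model.dt"  # case preserved, quotes stripped
```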
&TypeRegistry::get_nonlist_type(type_index type) const { } return *registered_types.at(type); } -} +} \ No newline at end of file diff --git a/learner/.gitignore b/learner/.gitignore index 21235fd7..2f289fa4 100644 --- a/learner/.gitignore +++ b/learner/.gitignore @@ -13,6 +13,7 @@ saved_models* data lifted plans +plots slg flg diff --git a/learner/dataset/dataset.py b/learner/dataset/dataset.py index ed80a53f..ffae20e4 100644 --- a/learner/dataset/dataset.py +++ b/learner/dataset/dataset.py @@ -1,39 +1,54 @@ import os import sys - sys.path.append(os.path.join(os.path.dirname(__file__), "..")) +import random +import numpy as np from util.stats import get_stats from torch_geometric.loader import DataLoader from sklearn.model_selection import train_test_split -from dataset.graphs import get_graph_data +from dataset.graphs_gnn import get_graph_data as get_graph_data_gnn +from dataset.graphs_kernel import get_graph_data as get_graph_data_kernel from dataset.transform import preprocess_data -def get_loaders_from_args(args): +def get_loaders_from_args_gnn(args): + model_name = args.model + batch_size = args.batch_size + domain = args.domain + rep = args.rep + max_nodes = args.max_nodes + cutoff = args.cutoff + small_train = args.small_train + num_workers = 0 + pin_memory = True + + dataset = get_graph_data_gnn(domain=domain, representation=rep) + dataset = preprocess_data(model_name, data_list=dataset, c_hi=cutoff, n_hi=max_nodes, small_train=small_train) + get_stats(dataset=dataset, desc="Whole dataset") + + trainset, valset = train_test_split(dataset, test_size=0.15, random_state=4550) + + get_stats(dataset=trainset, desc="Train set") + get_stats(dataset=valset, desc="Val set") + print("train size:", len(trainset)) + print("validation size:", len(valset)) - model_name = args.model - batch_size = args.batch_size - domain = args.domain - rep = args.rep - max_nodes = args.max_nodes - cutoff = args.cutoff - small_train = args.small_train - num_workers = 0 - pin_memory = True + train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, pin_memory=pin_memory, num_workers=num_workers) + val_loader = DataLoader(valset, batch_size=batch_size, shuffle=False, pin_memory=pin_memory, num_workers=num_workers) - dataset = get_graph_data(domain=domain, representation=rep) - dataset = preprocess_data(model_name, data_list=dataset, c_hi=cutoff, n_hi=max_nodes, small_train=small_train) - get_stats(dataset=dataset, desc="Whole dataset") + return train_loader, val_loader - trainset, valset = train_test_split(dataset, test_size=0.15, random_state=4550) +def get_dataset_from_args_kernels(args): + rep = args.rep + domain = args.domain - get_stats(dataset=trainset, desc="Train set") - get_stats(dataset=valset, desc="Val set") - print("train size:", len(trainset)) - print("validation size:", len(valset)) + dataset = get_graph_data_kernel(domain=domain, representation=rep) + if args.small_train: + dataset = random.sample(dataset, min(len(dataset, 1000))) + get_stats(dataset=dataset, desc="Whole dataset") - train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, pin_memory=pin_memory, num_workers=num_workers) - val_loader = DataLoader(valset, batch_size=batch_size, shuffle=False, pin_memory=pin_memory, num_workers=num_workers) + graphs = [data[0] for data in dataset] + y = np.array([data[1] for data in dataset]) - return train_loader, val_loader + return graphs, y diff --git a/learner/dataset/generate_all_graphs_gnn.sh b/learner/dataset/generate_all_graphs_gnn.sh new file mode 100644 index 
diff --git a/learner/dataset/generate_all_graphs_gnn.sh b/learner/dataset/generate_all_graphs_gnn.sh
new file mode 100644
index 00000000..dbb4a90f
--- /dev/null
+++ b/learner/dataset/generate_all_graphs_gnn.sh
@@ -0,0 +1,5 @@
+for rep in llg slg dlg glg flg
+do
+  echo "python3 dataset/generate_graphs_gnn.py $rep --regenerate"
+  python3 dataset/generate_graphs_gnn.py $rep --regenerate
+done
diff --git a/learner/dataset/generate_all_graphs_kernel.sh b/learner/dataset/generate_all_graphs_kernel.sh
new file mode 100644
index 00000000..dddbe333
--- /dev/null
+++ b/learner/dataset/generate_all_graphs_kernel.sh
@@ -0,0 +1,5 @@
+for rep in llg slg dlg glg flg
+do
+  echo "python3 dataset/generate_graphs_kernel.py $rep --regenerate"
+  python3 dataset/generate_graphs_kernel.py $rep --regenerate
+done
diff --git a/learner/scripts/generate_graphs.py b/learner/dataset/generate_graphs_gnn.py
similarity index 93%
rename from learner/scripts/generate_graphs.py
rename to learner/dataset/generate_graphs_gnn.py
index e08c98b7..b7e746eb 100644
--- a/learner/scripts/generate_graphs.py
+++ b/learner/dataset/generate_graphs_gnn.py
@@ -3,7 +3,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
 import argparse
 from representation import REPRESENTATIONS
-from dataset.graphs import gen_graph_rep
+from dataset.graphs_gnn import gen_graph_rep
 
 
 if __name__ == "__main__":
diff --git a/learner/dataset/generate_graphs_kernel.py b/learner/dataset/generate_graphs_kernel.py
new file mode 100644
index 00000000..862c6945
--- /dev/null
+++ b/learner/dataset/generate_graphs_kernel.py
@@ -0,0 +1,20 @@
+import os
+import sys
+sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
+import argparse
+from representation import REPRESENTATIONS
+from dataset.graphs_kernel import gen_graph_rep
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('rep', type=str, help="graph representation to generate", choices=REPRESENTATIONS)
+    parser.add_argument('-d', '--domain', type=str, help="domain to generate (useful for debugging)")
+    parser.add_argument('--regenerate', action="store_true")
+    args = parser.parse_args()
+
+    rep = args.rep
+    gen_graph_rep(representation=rep,
+                  regenerate=args.regenerate,
+                  domain=args.domain)
+    
\ No newline at end of file
diff --git a/learner/dataset/graphs.py b/learner/dataset/graphs_gnn.py
similarity index 88%
rename from learner/dataset/graphs.py
rename to learner/dataset/graphs_gnn.py
index e7abac68..c019475e 100644
--- a/learner/dataset/graphs.py
+++ b/learner/dataset/graphs_gnn.py
@@ -1,4 +1,4 @@
-""" File for generating and loading graphs. See scripts/generate_graphs.py """
+""" File for generating and loading graphs for GNNs.
+    Used by dataset/generate_graphs_gnn.py """
 
 import os
 import sys
@@ -18,6 +18,8 @@ from dataset.goose_domain_info import get_train_goose_instance_files
 
+_SAVE_DIR = "data/graphs_gnn"
+
 def generate_graph_from_domain_problem_pddl(
     domain_name: str,
     domain_pddl: str,
@@ -27,9 +29,6 @@ def generate_graph_from_domain_problem_pddl(
     """ Generates a list of graphs corresponding to states in the optimal plan """
     ret = []
 
-    if representation=="dlg":
-        return slg_to_dlg(domain_name, domain_pddl, problem_pddl)
-
     plan = optimal_plan_exists(domain_name, domain_pddl, problem_pddl)
     if plan is None:
         return None
@@ -44,7 +43,7 @@ def generate_graph_from_domain_problem_pddl(
         if REPRESENTATIONS[representation].lifted:
             s = rep.str_to_state(s)
 
-        x, edge_index = rep.get_state_enc(s)
+        x, edge_index = rep.state_to_tensor(s)
         applicable_action=None  # requires refactoring representation classes
         graph_data = Data(
             x=x,
@@ -58,18 +57,6 @@ def generate_graph_from_domain_problem_pddl(
         ret.append(graph_data)
     return ret
 
-def slg_to_dlg(domain_name, domain_pddl, problem_pddl):
-    problem_name = os.path.basename(problem_pddl).replace(".pddl", "")
-    f = f"data/graphs/sdg-el/{domain_name}/{problem_name}.data"
-    if not os.path.exists(f):
-        return None
-    graph_data_list = torch.load(f)
-    ret = []
-    for graph in graph_data_list:
-        graph.edge_index = graph.edge_index[:2]
-        ret.append(graph)
-    return ret
-
 def get_graph_data(
     representation: str,
     domain: str="all",
@@ -78,7 +65,7 @@ def get_graph_data(
 
     print("Loading train data...")
     print("NOTE: the data has been precomputed and saved.")
-    print("Rerun gen_data/graphs.py if representation has been updated!")
+    print("Exec 'python3 dataset/generate_graphs_gnn.py <rep> --regenerate' if representation has been updated!")
 
     path = get_data_dir_path(representation=representation)
     print(f"Path to data: {path}")
@@ -95,10 +82,7 @@ def get_graph_data(
         elif domain == "ipc-only":  # codebase getting bloated
             if "ipc-" not in domain_name:
                 continue
-        elif domain == "goose-pretraining":  # ipc + goose
-            if domain_name in goose_domain_info.DOMAINS_NOT_TO_TRAIN or "htg-" in domain_name:
-                continue
-        elif domain == "goose-unseen-pretraining":  # ipc only
+        elif domain == "goose-di":  # ipc only
             if domain_name in goose_domain_info.DOMAINS_NOT_TO_TRAIN or "htg-" in domain_name or "goose-" in domain_name:
                 continue
         else:
@@ -178,7 +162,7 @@ def gen_graph_rep(
     return
 
 def get_data_dir_path(representation: str) -> str:
-    save_dir = f'data/graphs/{representation}'
+    save_dir = f'{_SAVE_DIR}/{representation}'
     os.makedirs(save_dir, exist_ok=True)
     return save_dir
diff --git a/learner/dataset/graphs_kernel.py b/learner/dataset/graphs_kernel.py
new file mode 100644
index 00000000..41e59d9a
--- /dev/null
+++ b/learner/dataset/graphs_kernel.py
@@ -0,0 +1,203 @@
+""" File for generating and loading graphs for kernels.
+    Used by dataset/generate_graphs_kernel.py """
+
+import os
+import sys
+sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
+
+import torch
+import dataset.ipc_domain_info as ipc_domain_info
+import dataset.htg_domain_info as htg_domain_info
+import dataset.goose_domain_info as goose_domain_info
+
+from tqdm import tqdm, trange
+from typing import Dict, List, Optional, Tuple
+from representation import REPRESENTATIONS
+from dataset.htg_domain_info import get_all_htg_instance_files
+from dataset.ipc_domain_info import same_domain, GROUNDED_DOMAINS, get_ipc_domain_problem_files
+from dataset.goose_domain_info import get_train_goose_instance_files
+from representation import CGraph
+
+
+_SAVE_DIR = "data/graphs_kernel"
+Data = Tuple[CGraph, int]
+
+def generate_graph_from_domain_problem_pddl(
+    domain_name: str,
+    domain_pddl: str,
+    problem_pddl: str,
+    representation: str,
+) -> Optional[List[Data]]:
+    """ Generates a list of graphs corresponding to states in the optimal plan """
+    ret = []
+
+    plan = optimal_plan_exists(domain_name, domain_pddl, problem_pddl)
+    if plan is None:
+        return None
+
+    # see representation package
+    rep = REPRESENTATIONS[representation](domain_pddl, problem_pddl)
+    rep.convert_to_coloured_graph()
+
+    problem_name = os.path.basename(problem_pddl).replace(".pddl", "")
+
+    for s, y, a in plan:
+        if REPRESENTATIONS[representation].lifted:
+            s = rep.str_to_state(s)
+
+        graph = rep.state_to_cgraph(s)
+        ret.append((graph, y))
+    return ret
+
+def get_graph_data(
+    representation: str,
+    domain: str="all",
+) -> List[Data]:
+    """ Load stored generated graphs """
+
+    print("Loading train data...")
+    print("NOTE: the data has been precomputed and saved.")
+    print("Exec 'python3 dataset/generate_graphs_kernel.py <rep> --regenerate' if representation has been updated!")
+
+    path = get_data_dir_path(representation=representation)
+    print(f"Path to data: {path}")
+
+    ret = []
+    for domain_name in sorted(list(os.listdir(path))):
+        if ".data" in domain_name:
+            continue
+        if domain_name in ipc_domain_info.GENERAL_COST_DOMAINS or domain_name in htg_domain_info.GENERAL_COST_DOMAINS:
+            # tqdm.write(f"\t{domain_name} skipped since it does not have unit costs")
+            continue
+        if domain == "all":
+            pass  # accept everything
+        elif domain == "ipc-only":  # codebase getting bloated
+            if "ipc-" not in domain_name:
+                continue
+        elif domain == "goose-di":  # ipc only
+            if domain_name in goose_domain_info.DOMAINS_NOT_TO_TRAIN or "htg-" in domain_name or "goose-" in domain_name:
+                continue
+        else:
+            if "-only" not in domain and not same_domain(domain, domain_name):
+                continue
+            elif "-only" in domain and domain.replace("-only", "")!=domain_name:
+                continue
+
+        for data in sorted(list(os.listdir(f"{path}/{domain_name}"))):
+            next_data = torch.load(f'{path}/{domain_name}/{data}')
+            ret+=next_data
+
+    print(f"{domain} dataset of size {len(ret)} loaded!")
+    return ret
+
+def generate_graph_rep_domain(
+    domain_name: str,
+    domain_pddl: str,
+    problem_pddl: str,
+    representation: str,
+    regenerate: bool
+) -> int:
+    """ Saves a list of (CGraph, cost-to-go) pairs to file.
+        Returns whether a new graph was generated or not
+    """
+    save_file = get_data_path(domain_name,
+                              domain_pddl,
+                              problem_pddl,
+                              representation)
+    if os.path.exists(save_file):
+        if not regenerate:
+            return 0
+        else:
+            os.remove(save_file)  # make a fresh set of data
+
+    graph = generate_graph_from_domain_problem_pddl(domain_name=domain_name,
+                                                    domain_pddl=domain_pddl,
+                                                    problem_pddl=problem_pddl,
+                                                    representation=representation)
+    if graph is not None:
+        tqdm.write(f'saving data @{save_file}...')
+        torch.save(graph, save_file)
+        tqdm.write('data saved!')
+        return 1
+    return 0
+
+def gen_graph_rep(
+    representation: str,
+    regenerate: bool,
+    domain: str,
+) -> None:
+    """ Generate graph representations from saved optimal plans. """
+
+    # tasks = get_ipc_domain_problem_files(del_free=False)
+    # tasks += get_all_htg_instance_files(split=True)
+    tasks = get_train_goose_instance_files()
+
+    new_generated = 0
+    pbar = tqdm(tasks)
+    for domain_name, domain_pddl, problem_pddl in pbar:
+        problem_name = os.path.basename(problem_pddl).replace(".pddl", "")
+        # if representation in LIFTED_REPRESENTATIONS and domain_name in GROUNDED_DOMAINS:
+        #     continue
+        pbar.set_description(f"Generating {representation} graphs for {domain_name} {problem_name}")
+
+        # in case we only want to generate graphs for one specific domain
+        if domain is not None and domain != domain_name:
+            continue
+
+        new_generated += generate_graph_rep_domain(domain_name=domain_name,
+                                                   domain_pddl=domain_pddl,
+                                                   problem_pddl=problem_pddl,
+                                                   representation=representation,
+                                                   regenerate=regenerate)
+    print(f"newly generated graphs: {new_generated}")
+    return
+
+def get_data_dir_path(representation: str) -> str:
+    save_dir = f'{_SAVE_DIR}/{representation}'
+    os.makedirs(save_dir, exist_ok=True)
+    return save_dir
+
+def get_data_path(domain_name: str,
+                  domain_pddl: str,
+                  problem_pddl: str,
+                  representation: str) -> str:
+    """ Get path to save file of graph training data of given domain.
""" + problem_name = os.path.basename(problem_pddl).replace(".pddl", "") + save_dir = f'{get_data_dir_path(representation)}/{domain_name}' + save_file = f'{save_dir}/{problem_name}.data' + os.makedirs(save_dir, exist_ok=True) + return save_file + +def optimal_plan_exists(domain_name: str, domain_pddl: str, problem_pddl: str): + domain_name = domain_name.replace("htg-", '') + problem_name = os.path.basename(problem_pddl) + save_dir = f'data/plan_objects/{domain_name}' + save_path = f'{save_dir}/{problem_name}.states'.replace(".pddl", "") + if os.path.exists(save_path): # if plan found, load and return + data = [] + lines = open(save_path, 'r').readlines() + plan_length = len(lines)-1 + for i, line in enumerate(lines): + if line[0]==";": + assert "GOOD" in line + else: + line = line.replace("\n", "") + s = set() + for fact in line.split(): + if "(" not in fact: + lime = f"({fact})" + else: + pred = fact[:fact.index("(")] + fact = fact.replace(pred+"(","").replace(")","") + args = fact.split(",")[:-1] + lime = f"({pred}" + for j, arg in enumerate(args): + lime+=f" {arg}" + if j == len(args)-1: + lime+=")" + s.add(lime) + y = plan_length - i - 1 + a = None + data.append((s, y, a)) + return data + else: + return None diff --git a/learner/dataset/transform.py b/learner/dataset/transform.py index e1505daa..6352f264 100644 --- a/learner/dataset/transform.py +++ b/learner/dataset/transform.py @@ -5,12 +5,12 @@ import random import torch -import models +import gnns from torch import Tensor from typing import Dict, List, Optional, Tuple from torch_geometric.data import DataLoader, Data from tqdm import tqdm, trange -from dataset.graphs import get_graph_data +from dataset.graphs_gnn import get_graph_data def extract_testset_domain( diff --git a/learner/models/__init__.py b/learner/gnns/__init__.py similarity index 100% rename from learner/models/__init__.py rename to learner/gnns/__init__.py diff --git a/learner/models/base_gnn.py b/learner/gnns/base_gnn.py similarity index 97% rename from learner/models/base_gnn.py rename to learner/gnns/base_gnn.py index 5ecf1bac..478fc67e 100644 --- a/learner/models/base_gnn.py +++ b/learner/gnns/base_gnn.py @@ -5,8 +5,7 @@ import time import warnings from planning import Proposition, State -from representation import REPRESENTATIONS -from representation.base_class import Representation +from representation import REPRESENTATIONS, Representation from torch_geometric.nn import (global_add_pool, global_max_pool, global_mean_pool) from abc import ABC, abstractmethod from torch_geometric.nn import MessagePassing @@ -194,7 +193,7 @@ def initialise_readout(self): return def h(self, state: State) -> float: - x, edge_index = self.rep.get_state_enc(state) + x, edge_index = self.rep.state_to_tensor(state) x = x.to(self.device) edge_index = edge_index.to(self.device) h = self.model.forward(x, edge_index, None).item() @@ -204,7 +203,7 @@ def h(self, state: State) -> float: def h_batch(self, states: List[State]) -> List[float]: data_list = [] for state in states: - x, edge_index = self.rep.get_state_enc(state) + x, edge_index = self.rep.state_to_tensor(state) data_list.append(Data(x=x, edge_index=edge_index)) loader = DataLoader(dataset=data_list, batch_size=min(len(data_list), 32)) data = next(iter(loader)).to(self.device) diff --git a/learner/models/elmpnn.py b/learner/gnns/elmpnn.py similarity index 98% rename from learner/models/elmpnn.py rename to learner/gnns/elmpnn.py index 228705a2..3b37e429 100644 --- a/learner/models/elmpnn.py +++ b/learner/gnns/elmpnn.py @@ -61,7 
diff --git a/learner/dataset/transform.py b/learner/dataset/transform.py
index e1505daa..6352f264 100644
--- a/learner/dataset/transform.py
+++ b/learner/dataset/transform.py
@@ -5,12 +5,12 @@ import random
 import torch
-import models
+import gnns
 
 from torch import Tensor
 from typing import Dict, List, Optional, Tuple
 from torch_geometric.data import DataLoader, Data
 from tqdm import tqdm, trange
-from dataset.graphs import get_graph_data
+from dataset.graphs_gnn import get_graph_data
 
 
 def extract_testset_domain(
diff --git a/learner/models/__init__.py b/learner/gnns/__init__.py
similarity index 100%
rename from learner/models/__init__.py
rename to learner/gnns/__init__.py
diff --git a/learner/models/base_gnn.py b/learner/gnns/base_gnn.py
similarity index 97%
rename from learner/models/base_gnn.py
rename to learner/gnns/base_gnn.py
index 5ecf1bac..478fc67e 100644
--- a/learner/models/base_gnn.py
+++ b/learner/gnns/base_gnn.py
@@ -5,8 +5,7 @@ import time
 import warnings
 
 from planning import Proposition, State
-from representation import REPRESENTATIONS
-from representation.base_class import Representation
+from representation import REPRESENTATIONS, Representation
 from torch_geometric.nn import (global_add_pool, global_max_pool, global_mean_pool)
 from abc import ABC, abstractmethod
 from torch_geometric.nn import MessagePassing
@@ -194,7 +193,7 @@ def initialise_readout(self):
         return
 
     def h(self, state: State) -> float:
-        x, edge_index = self.rep.get_state_enc(state)
+        x, edge_index = self.rep.state_to_tensor(state)
         x = x.to(self.device)
         edge_index = edge_index.to(self.device)
         h = self.model.forward(x, edge_index, None).item()
@@ -204,7 +203,7 @@ def h_batch(self, states: List[State]) -> List[float]:
         data_list = []
         for state in states:
-            x, edge_index = self.rep.get_state_enc(state)
+            x, edge_index = self.rep.state_to_tensor(state)
             data_list.append(Data(x=x, edge_index=edge_index))
         loader = DataLoader(dataset=data_list, batch_size=min(len(data_list), 32))
         data = next(iter(loader)).to(self.device)
diff --git a/learner/models/elmpnn.py b/learner/gnns/elmpnn.py
similarity index 98%
rename from learner/models/elmpnn.py
rename to learner/gnns/elmpnn.py
index 228705a2..3b37e429 100644
--- a/learner/models/elmpnn.py
+++ b/learner/gnns/elmpnn.py
@@ -61,7 +61,7 @@ def create_model(self, params):
         self.model = ELMPNN(params)
 
     def h(self, state: State) -> float:
-        x, edge_index = self.rep.get_state_enc(state)
+        x, edge_index = self.rep.state_to_tensor(state)
         x = x.to(self.device)
         for i in range(len(edge_index)):
             edge_index[i] = edge_index[i].to(self.device)
diff --git a/learner/models/loss.py b/learner/gnns/loss.py
similarity index 100%
rename from learner/models/loss.py
rename to learner/gnns/loss.py
diff --git a/learner/models/mpnn.py b/learner/gnns/mpnn.py
similarity index 100%
rename from learner/models/mpnn.py
rename to learner/gnns/mpnn.py
diff --git a/learner/kernels/__init__.py b/learner/kernels/__init__.py
new file mode 100644
index 00000000..bb478c0c
--- /dev/null
+++ b/learner/kernels/__init__.py
@@ -0,0 +1,5 @@
+from .wl import WeisfeilerLehmanKernel
+
+KERNELS = {
+    "wl": WeisfeilerLehmanKernel
+}
\ No newline at end of file
diff --git a/learner/kernels/base_kernel.py b/learner/kernels/base_kernel.py
new file mode 100644
index 00000000..7bad0e24
--- /dev/null
+++ b/learner/kernels/base_kernel.py
@@ -0,0 +1,27 @@
+import os
+import sys
+sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
+
+import numpy as np
+import networkx as nx
+from abc import ABC, abstractmethod
+from typing import List
+from representation import CGraph
+
+
+""" Base class for graph kernels """
+class Kernel(ABC):
+    def __init__(self) -> None:
+        return
+
+    @abstractmethod
+    def read_train_data(self, graphs: List[CGraph]) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_x(self, graphs: List[CGraph]) -> np.ndarray:
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_k(self, graphs: List[CGraph]) -> np.ndarray:
+        raise NotImplementedError
\ No newline at end of file
diff --git a/learner/kernels/wl.py b/learner/kernels/wl.py
new file mode 100644
index 00000000..f3a951b8
--- /dev/null
+++ b/learner/kernels/wl.py
@@ -0,0 +1,121 @@
+import time
+from .base_kernel import *
+
+
+class WeisfeilerLehmanKernel(Kernel):
+    def __init__(self, iterations: int, all_colours: bool) -> None:
+        super().__init__()
+
+        # hashes neighbour multisets of colours; same as self._representation if all_colours
+        self._hash = {}
+
+        # option for returning only final WL iteration
+        self._representation = {}
+
+        # number of wl iterations
+        self.iterations = iterations
+
+        # collect colours from all iterations or only final
+        self.all_colours = all_colours
+
+    def _get_hash_value(self, colour) -> int:
+        if colour not in self._hash:
+            self._hash[colour] = len(self._hash)
+        return self._hash[colour]
+
+    def read_train_data(self, graphs: List[CGraph]) -> None:
+        """ Read data and precompute the hash function """
+
+        t = time.time()
+        self._train_data_colours = {}
+
+        # compute colours and hashmap from training data
+        for G in graphs:
+            cur_colours = {}
+            histogram = {}
+
+            def store_colour(colour):
+                nonlocal histogram
+                if colour not in self._representation:
+                    self._representation[colour] = len(self._representation)
+                if colour not in histogram:
+                    histogram[colour] = 0
+                histogram[colour] += 1
+
+            # collect initial colours
+            for u in G.nodes:
+
+                # initial colour is feature of the node
+                colour = G.nodes[u]["colour"]
+                cur_colours[u] = self._get_hash_value(colour)
+
+                # store histogram for all iterations or only last
+                if self.all_colours or self.iterations == 0:
+                    store_colour(colour)
+
+            # WL iterations
+            for itr in range(self.iterations):
+                new_colours = {}
+                for u in G.nodes:
+
+                    # edge label WL variant
+                    neighbour_colours = []
+                    for v in G[u]:
+                        colour_node = cur_colours[v]
+                        colour_edge = G.edges[(u,v)]["edge_label"]
+                        neighbour_colours.append((colour_node, colour_edge))
+                    neighbour_colours = sorted(neighbour_colours)
+                    colour = tuple([cur_colours[u]] + neighbour_colours)
+                    new_colours[u] = self._get_hash_value(colour)
+
+                    # store histogram for all iterations or only last
+                    if self.all_colours or itr == self.iterations - 1:
+                        store_colour(colour)
+                cur_colours = new_colours
+
+            # store histogram of graph colours
+            self._train_data_colours[G] = histogram
+
+        if self.all_colours:
+            self._representation = self._hash
+
+        t = time.time() - t
+        print(f"Initialised WL for {len(graphs)} graphs in {t:.2f}s")
+        print(f"Collected {len(self._hash)} colours over {sum(len(G.nodes) for G in graphs)} nodes")
+        return
+
+    def get_x(self, graphs: List[CGraph]) -> np.ndarray:
+        """ Explicit feature representation
+            O(nd) time; n x d output
+        """
+        n = len(graphs)
+        d = len(self._representation)
+        X = np.zeros((n, d))
+        for i, G in enumerate(graphs):
+            histogram = self._train_data_colours[G]
+            for colour in histogram:
+                j = self._representation[colour]
+                X[i][j] = histogram[colour]
+        return X
+
+    def get_k(self, graphs: List[CGraph]) -> np.ndarray:
+        """ Implicit feature representation
+            O(n^2d) time; n x n output
+        """
+        n = len(graphs)
+        K = np.zeros((n, n))
+        for i in range(n):
+            for j in range(i, n):
+                k = 0
+
+                histogram_i = self._train_data_colours[graphs[i]]
+                histogram_j = self._train_data_colours[graphs[j]]
+
+                common_colours = set(histogram_i.keys()).intersection(set(histogram_j.keys()))
+                for c in common_colours:
+                    k += histogram_i[c] * histogram_j[c]
+
+                K[i][j] = k
+                K[j][i] = k
+        return K
+    
\ No newline at end of file
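End to end, the WL features plug into any scikit-learn regressor, either over the explicit representation from `get_x` or via the precomputed Gram matrix from `get_k`. A hedged sketch (train_kernel.py's actual model and hyperparameters are not part of this diff):

```python
# Sketch, assuming graphs/y from get_dataset_from_args_kernels above;
# the choice of Ridge and the WL settings are illustrative.
from sklearn.linear_model import Ridge
from kernels import KERNELS

wl = KERNELS["wl"](iterations=4, all_colours=True)
wl.read_train_data(graphs)   # builds the colour hash and per-graph histograms
X = wl.get_x(graphs)         # n x d colour-count matrix
model = Ridge().fit(X, y)    # predicts cost-to-go from WL features

# implicit alternative: K = wl.get_k(graphs) with, e.g.,
# sklearn.svm.SVR(kernel="precomputed").fit(K, y)
```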
-> None: return def convert_to_pyg(self) -> None: - """ Converts networkx graph object into pytorch_geometric tensors. + """ Converts nx graph into pytorch_geometric tensors and stores them. The tensors are (x, edge_index or edge_indices) x: torch.tensor(N x F) # N = num_nodes, F = num_features @@ -117,8 +130,8 @@ def convert_to_pyg(self) -> None: assert self.n_edge_labels > 1 self.edge_indices = [[] for _ in range(self.n_edge_labels)] edge_index_T = pyg_G.edge_index.T - for i, edge_type in enumerate(pyg_G.edge_type): - self.edge_indices[edge_type].append(edge_index_T[i]) + for i, edge_label in enumerate(pyg_G.edge_label): + self.edge_indices[edge_label].append(edge_index_T[i]) for i in range(self.n_edge_labels): if len(self.edge_indices[i]) > 0: self.edge_indices[i] = torch.vstack(self.edge_indices[i]).long().T @@ -126,10 +139,44 @@ def convert_to_pyg(self) -> None: self.edge_indices[i] = torch.tensor([[], []]).long() return + def convert_to_coloured_graph(self) -> None: + """ Converts nx graph into another nx graph but with colours instead of vector features. + + Vector features are converted to colours with a hash. This can be hardcoded slightly more + efficiently for each graph representation separately but takes more effort. + """ + + # TODO optimise by converting node string names into ints and storing the map + + colours = set() + + c_graph = self._create_graph() + for node in self.G.nodes: + feature = self.G.nodes[node]['x'].tolist() + feature = str(tuple(feature)) + if self.name == "llg" and type(node) == tuple and len(node)==2 and \ + type(node[1]) == str and "var-" in node[1]: + index = node[1].split('-')[-1] + colour = index+IF_COLOUR_SUFFIX + else: + colour = hashlib.sha256(feature.encode('utf-8')).hexdigest() + colours.add(colour) + c_graph.add_node(node, colour=colour) + for edge in self.G.edges: + u, v = edge + c_graph.add_edge(u_of_edge=u, v_of_edge=v, edge_label=self.G.edges[edge]["edge_label"]) + + self.c_graph = c_graph + return + @abstractmethod def _compute_graph_representation(self) -> None: raise NotImplementedError @abstractmethod - def get_state_enc(self, state: State): + def state_to_tensor(self, state: State) -> TGraph: + raise NotImplementedError + + @abstractmethod + def state_to_cgraph(self, state: State) -> CGraph: raise NotImplementedError diff --git a/learner/representation/dlg.py b/learner/representation/dlg.py index 5da5b993..e85e1b60 100644 --- a/learner/representation/dlg.py +++ b/learner/representation/dlg.py @@ -1,4 +1,4 @@ -from representation.base_class import * +from .base_class import * from representation.slg import StripsLearningGraph @@ -9,7 +9,7 @@ class DLG_FEATURES(Enum): STATE=3 -class DLG_EDGE_TYPES(Enum): +class DLG_EDGE_LABELS(Enum): PRE_EDGE=0 ADD_EDGE=1 @@ -17,7 +17,7 @@ class DLG_EDGE_TYPES(Enum): class DeleteLearningGraph(StripsLearningGraph, ABC): name = "dlg" n_node_features = len(DLG_FEATURES) - n_edge_labels = len(DLG_EDGE_TYPES) + n_edge_labels = len(DLG_EDGE_LABELS) directed = False lifted = False @@ -38,8 +38,10 @@ def _compute_graph_representation(self) -> None: # these features may get updated in state encoding if proposition in positive_goals: x_p = self._one_hot_node(DLG_FEATURES.POSITIVE_GOAL.value) + self._pos_goal_nodes.add(node_p) elif proposition in negative_goals: x_p = self._one_hot_node(DLG_FEATURES.NEGATIVE_GOAL.value) + self._neg_goal_nodes.add(node_p) else: x_p = self._zero_node() G.add_node(node_p, x=x_p) @@ -56,19 +58,19 @@ def _compute_graph_representation(self) -> None: p_node = 
diff --git a/learner/representation/dlg.py b/learner/representation/dlg.py
index 5da5b993..e85e1b60 100644
--- a/learner/representation/dlg.py
+++ b/learner/representation/dlg.py
@@ -1,4 +1,4 @@
-from representation.base_class import *
+from .base_class import *
 from representation.slg import StripsLearningGraph
 
 
@@ -9,7 +9,7 @@ class DLG_FEATURES(Enum):
     STATE=3
 
 
-class DLG_EDGE_TYPES(Enum):
+class DLG_EDGE_LABELS(Enum):
     PRE_EDGE=0
     ADD_EDGE=1
 
@@ -17,7 +17,7 @@ class DLG_EDGE_TYPES(Enum):
 class DeleteLearningGraph(StripsLearningGraph, ABC):
     name = "dlg"
     n_node_features = len(DLG_FEATURES)
-    n_edge_labels = len(DLG_EDGE_TYPES)
+    n_edge_labels = len(DLG_EDGE_LABELS)
     directed = False
     lifted = False
 
@@ -38,8 +38,10 @@ def _compute_graph_representation(self) -> None:
             # these features may get updated in state encoding
             if proposition in positive_goals:
                 x_p = self._one_hot_node(DLG_FEATURES.POSITIVE_GOAL.value)
+                self._pos_goal_nodes.add(node_p)
             elif proposition in negative_goals:
                 x_p = self._one_hot_node(DLG_FEATURES.NEGATIVE_GOAL.value)
+                self._neg_goal_nodes.add(node_p)
             else:
                 x_p = self._zero_node()
             G.add_node(node_p, x=x_p)
@@ -56,19 +58,19 @@ def _compute_graph_representation(self) -> None:
                 p_node = self._proposition_to_str(proposition)
                 assert p_node in G.nodes, f"{p_node} not in nodes"
                 assert a_node in G.nodes, f"{a_node} not in nodes"
-                G.add_edge(u_of_edge=p_node, v_of_edge=a_node, edge_type=DLG_EDGE_TYPES.PRE_EDGE.value)
+                G.add_edge(u_of_edge=p_node, v_of_edge=a_node, edge_label=DLG_EDGE_LABELS.PRE_EDGE.value)
             for _, proposition in action.add_effects:  # ignoring conditional effects
                 p_node = self._proposition_to_str(proposition)
                 assert p_node in G.nodes, f"{p_node} not in nodes"
                 assert a_node in G.nodes, f"{a_node} not in nodes"
-                G.add_edge(u_of_edge=p_node, v_of_edge=a_node, edge_type=DLG_EDGE_TYPES.ADD_EDGE.value)
+                G.add_edge(u_of_edge=p_node, v_of_edge=a_node, edge_label=DLG_EDGE_LABELS.ADD_EDGE.value)
             """ Delete relaxation means ignoring delete edges """
             # for _, proposition in action.del_effects:  # ignoring conditional effects
             #     p_node = self._proposition_to_str(proposition)
             #     assert p_node in G.nodes, f"{p_node} not in nodes"
             #     assert a_node in G.nodes, f"{a_node} not in nodes"
-            #     G.add_edge(u_of_edge=p_node, v_of_edge=a_node, edge_type=SDG_EDGE_TYPES.DEL_EDGE.value)
+            #     G.add_edge(u_of_edge=p_node, v_of_edge=a_node, edge_label=SDG_EDGE_LABELS.DEL_EDGE.value)
 
         # map node name to index
         self._node_to_i = {}
@@ -78,7 +80,7 @@ def _compute_graph_representation(self) -> None:
 
         return
 
-    def get_state_enc(self, state: State) -> Tuple[Tensor, Tensor]:
+    def state_to_tensor(self, state: State) -> Tuple[Tensor, Tensor]:
         x = self.x.clone()
         # not time nor memory efficient, but no other way in Python
         for p in state:
diff --git a/learner/representation/flg.py b/learner/representation/flg.py
index a7bc5425..5c8c886f 100644
--- a/learner/representation/flg.py
+++ b/learner/representation/flg.py
@@ -1,4 +1,4 @@
-from representation.base_class import *
+from .base_class import *
 
 
 class FLG_FEATURES(Enum):
@@ -9,7 +9,7 @@ class FLG_FEATURES(Enum):
     ACTION=4
 
 
-class FLG_EDGE_TYPES(Enum):
+class FLG_EDGE_LABELS(Enum):
     VV_EDGE=0
     PRE_EDGE=1
     EFF_EDGE=2
@@ -18,7 +18,7 @@ class FLG_EDGE_TYPES(Enum):
 class FdrLearningGraph(Representation, ABC):
     name = "flg"
     n_node_features = len(FLG_FEATURES)
-    n_edge_labels = len(FLG_EDGE_TYPES)
+    n_edge_labels = len(FLG_EDGE_LABELS)
     directed = False
     lifted = False
 
@@ -52,7 +52,7 @@ def _compute_graph_representation(self) -> None:
                 val_x += self._one_hot_node(FLG_FEATURES.GOAL.value)
             G.add_node(val_node, x=val_x)
-            G.add_edge(u_of_edge=var, v_of_edge=val_node, edge_type=FLG_EDGE_TYPES.VV_EDGE.value)
+            G.add_edge(u_of_edge=var, v_of_edge=val_node, edge_label=FLG_EDGE_LABELS.VV_EDGE.value)
         assert goals == len(goal)
 
         """ action nodes and edges """
@@ -63,13 +63,13 @@ def _compute_graph_representation(self) -> None:
                 assert val in variables[var]  # and hence should be in G.nodes()
                 val_node = (var, val)
                 assert val_node in G.nodes()
-                G.add_edge(u_of_edge=action_node, v_of_edge=val_node, edge_type=FLG_EDGE_TYPES.PRE_EDGE.value)
+                G.add_edge(u_of_edge=action_node, v_of_edge=val_node, edge_label=FLG_EDGE_LABELS.PRE_EDGE.value)
             for var, val in action.add_effects:  # from our compilation, effects are in add only
                 assert val in variables[var]
                 val_node = (var, val)
                 assert val_node in G.nodes()
-                G.add_edge(u_of_edge=action_node, v_of_edge=val_node, edge_type=FLG_EDGE_TYPES.EFF_EDGE.value)
+                G.add_edge(u_of_edge=action_node, v_of_edge=val_node, edge_label=FLG_EDGE_LABELS.EFF_EDGE.value)
 
         # map node name to index
         node_to_i = {}
@@ -82,7 +82,7 @@ def _compute_graph_representation(self) -> None:
 
         return
 
-    def get_state_enc(self, state: State) -> Tuple[Tensor, Tensor]:
+    def state_to_tensor(self, state: State) -> Tuple[Tensor, Tensor]:
         x = self.x.clone()
 
         for p in state:
diff --git a/learner/representation/glg.py b/learner/representation/glg.py
index efffac6d..b8dfc981 100644
--- a/learner/representation/glg.py
+++ b/learner/representation/glg.py
@@ -1,5 +1,5 @@
-from representation.base_class import *
-from representation.slg import StripsLearningGraph
+from .base_class import *
+from .slg import StripsLearningGraph
 
 
 class GLG_FEATURES(Enum):
@@ -11,7 +11,7 @@ class GLG_FEATURES(Enum):
     SCHEMA=5
 
 
-class GLG_EDGE_TYPES(Enum):
+class GLG_EDGE_LABELS(Enum):
     PRE_EDGE=0
     ADD_EDGE=1
     DEL_EDGE=2
@@ -21,7 +21,7 @@ class GLG_EDGE_TYPES(Enum):
 class GroundedLearningGraph(StripsLearningGraph, ABC):
     name = "glg"
     n_node_features = len(GLG_FEATURES)
-    n_edge_labels = len(GLG_EDGE_TYPES)
+    n_edge_labels = len(GLG_EDGE_LABELS)
     directed = False
     lifted = False
 
@@ -41,9 +41,11 @@ def _compute_graph_representation(self) -> None:
             node_p = self._proposition_to_str(proposition)
             # these features may get updated in state encoding
             if proposition in positive_goals:
-                x_p=self._one_hot_node(GLG_FEATURES.POSITIVE_GOAL.value)
+                x_p = self._one_hot_node(GLG_FEATURES.POSITIVE_GOAL.value)
+                self._pos_goal_nodes.add(node_p)
             elif proposition in negative_goals:
-                x_p=self._one_hot_node(GLG_FEATURES.NEGATIVE_GOAL.value)
+                x_p = self._one_hot_node(GLG_FEATURES.NEGATIVE_GOAL.value)
+                self._neg_goal_nodes.add(node_p)
             else:
                 x_p=self._zero_node()
             G.add_node(node_p, x=x_p)
@@ -65,24 +67,24 @@ def _compute_graph_representation(self) -> None:
             s_node = self._get_predicate_from_action(action)
             assert a_node in G.nodes
             assert s_node in G.nodes
-            G.add_edge(u_of_edge=a_node, v_of_edge=s_node, edge_type=GLG_EDGE_TYPES.PREDICATE.value)
+            G.add_edge(u_of_edge=a_node, v_of_edge=s_node, edge_label=GLG_EDGE_LABELS.PREDICATE.value)
 
             # edges between actions and propositions
             for proposition in action.precondition:
                 p_node = self._proposition_to_str(proposition)
                 assert p_node in G.nodes, f"{p_node} not in nodes"
                 assert a_node in G.nodes, f"{a_node} not in nodes"
-                G.add_edge(u_of_edge=p_node, v_of_edge=a_node, edge_type=GLG_EDGE_TYPES.PRE_EDGE.value)
+                G.add_edge(u_of_edge=p_node, v_of_edge=a_node, edge_label=GLG_EDGE_LABELS.PRE_EDGE.value)
             for _, proposition in action.add_effects:  # ignoring conditional effects
                 p_node = self._proposition_to_str(proposition)
                 assert p_node in G.nodes, f"{p_node} not in nodes"
                 assert a_node in G.nodes, f"{a_node} not in nodes"
-                G.add_edge(u_of_edge=p_node, v_of_edge=a_node, edge_type=GLG_EDGE_TYPES.ADD_EDGE.value)
+                G.add_edge(u_of_edge=p_node, v_of_edge=a_node, edge_label=GLG_EDGE_LABELS.ADD_EDGE.value)
             for _, proposition in action.del_effects:  # ignoring conditional effects
                 p_node = self._proposition_to_str(proposition)
                 assert p_node in G.nodes, f"{p_node} not in nodes"
                 assert a_node in G.nodes, f"{a_node} not in nodes"
-                G.add_edge(u_of_edge=p_node, v_of_edge=a_node, edge_type=GLG_EDGE_TYPES.DEL_EDGE.value)
+                G.add_edge(u_of_edge=p_node, v_of_edge=a_node, edge_label=GLG_EDGE_LABELS.DEL_EDGE.value)
 
         for proposition in propositions:
             # edge between propositions and predicates
             p_node = self._proposition_to_str(proposition)
             pred_node = self._get_predicate_from_proposition(proposition)
             assert p_node in G.nodes
             assert pred_node in G.nodes
-            G.add_edge(u_of_edge=p_node, v_of_edge=pred_node, edge_type=GLG_EDGE_TYPES.PREDICATE.value)
+            G.add_edge(u_of_edge=p_node, v_of_edge=pred_node, edge_label=GLG_EDGE_LABELS.PREDICATE.value)
 
         # map node name to index
         self._node_to_i = {}
@@ -100,7 +102,7 @@ def _compute_graph_representation(self) -> None:
 
         return
 
-    def get_state_enc(self, state: State) -> Tuple[Tensor, Tensor]:
+    def state_to_tensor(self, state: State) -> Tuple[Tensor, Tensor]:
         x = self.x.clone()
         # not time nor memory efficient, but no other way in Python
         for p in state:
diff --git a/learner/representation/llg.py b/learner/representation/llg.py
index 841ee0ab..215f40e3 100644
--- a/learner/representation/llg.py
+++ b/learner/representation/llg.py
@@ -1,5 +1,6 @@
+from .base_class import *
 from planning.translate.pddl import Atom, NegatedAtom, Truth
-from representation.base_class import *
+
 
 class LLG_FEATURES(Enum):
     P=0  # is predicate
@@ -13,8 +14,7 @@ class LLG_FEATURES(Enum):
 ENC_FEAT_SIZE = len(LLG_FEATURES)
 VAR_FEAT_SIZE = 4
 
-
-LLG_EDGE_TYPES = OrderedDict({
+LLG_EDGE_LABELS = OrderedDict({
     "neutral": 0,
     "ground": 1,
     "pre_pos": 2,
@@ -22,13 +22,12 @@ class LLG_FEATURES(Enum):
     "eff_pos": 4,
     "eff_neg": 5,
 })
-
 
 class LiftedLearningGraph(Representation, ABC):
     name = "llg"
     n_node_features = ENC_FEAT_SIZE+VAR_FEAT_SIZE
-    n_edge_labels = len(LLG_EDGE_TYPES)
+    n_edge_labels = len(LLG_EDGE_LABELS)
     directed = False
     lifted = True
 
@@ -37,14 +36,18 @@ def __init__(self, domain_pddl: str, problem_pddl: str):
 
     def _construct_if(self) -> None:
         """ Precompute a seeded randomly generated injective index function """
-        self._pe = []
+        self._if = []
+        image = set()  # check injectiveness
         # TODO read max range from problem and lazily compute
         for idx in range(60):
             torch.manual_seed(idx)
             rep = 2*torch.rand(VAR_FEAT_SIZE)-1  # U[-1,1]
             rep /= torch.linalg.norm(rep)
-            self._pe.append(rep)
+            self._if.append(rep)
+            key = tuple(rep.tolist())
+            assert key not in image
+            image.add(key)
         return
 
     def _feature(self, node_type: LLG_FEATURES) -> Tensor:
@@ -54,7 +57,7 @@ def _feature(self, node_type: LLG_FEATURES) -> Tensor:
 
     def _if_feature(self, idx: int) -> Tensor:
         ret = torch.zeros(self.n_node_features)
-        ret[-VAR_FEAT_SIZE:] = self._pe[idx]
+        ret[-VAR_FEAT_SIZE:] = self._if[idx]
         return ret
 
     def _compute_graph_representation(self) -> None:
@@ -80,10 +83,10 @@ def _compute_graph_representation(self) -> None:
         # fully connected between objects and predicates
         for pred in self.problem.predicates:
             for obj in self.problem.objects:
-                G.add_edge(u_of_edge=pred.name, v_of_edge=obj.name, edge_type=LLG_EDGE_TYPES["neutral"])
+                G.add_edge(u_of_edge=pred.name, v_of_edge=obj.name, edge_label=LLG_EDGE_LABELS["neutral"])
 
-        # goal (state gets dealt with in get_state_enc)
+        # goal (state gets dealt with in state_to_tensor)
         if len(self.problem.goal.parts) == 0:
             goals = [self.problem.goal]
         else:
@@ -100,8 +103,10 @@ def _compute_graph_representation(self) -> None:
 
             if is_negated:
                 x = self._feature(LLG_FEATURES.N)
-            else:
+                self._neg_goal_nodes.add(goal_node)
+            else:
                 x = self._feature(LLG_FEATURES.G)
+                self._pos_goal_nodes.add(goal_node)
             G.add_node(goal_node, x=x)  # add grounded predicate node
 
             for i, arg in enumerate(args):
@@ -109,15 +114,15 @@ def _compute_graph_representation(self) -> None:
                 G.add_node(goal_var_node, x=self._if_feature(idx=i))
 
                 # connect variable to predicate
-                G.add_edge(u_of_edge=goal_node, v_of_edge=goal_var_node, edge_type=LLG_EDGE_TYPES["ground"])
+                G.add_edge(u_of_edge=goal_node, v_of_edge=goal_var_node, edge_label=LLG_EDGE_LABELS["ground"])
 
                 # connect variable to object
                 assert arg in G.nodes()
-                G.add_edge(u_of_edge=goal_var_node, v_of_edge=arg, edge_type=LLG_EDGE_TYPES["ground"])
+                G.add_edge(u_of_edge=goal_var_node, v_of_edge=arg, edge_label=LLG_EDGE_LABELS["ground"])
 
                 # connect grounded fact to predicate
                 assert pred in G.nodes()
-                G.add_edge(u_of_edge=goal_node, v_of_edge=pred, edge_type=LLG_EDGE_TYPES["ground"])
+                G.add_edge(u_of_edge=goal_node, v_of_edge=pred, edge_label=LLG_EDGE_LABELS["ground"])
 
         # end goal
 
@@ -132,28 +137,28 @@ def _compute_graph_representation(self) -> None:
             arg_node = (action.name, f"action-var-{i}")  # action var
             G.add_node(arg_node, x=self._if_feature(idx=i))
             action_args[arg.name] = arg_node
-            G.add_edge(u_of_edge=action.name, v_of_edge=arg_node, edge_type=LLG_EDGE_TYPES["neutral"])
+            G.add_edge(u_of_edge=action.name, v_of_edge=arg_node, edge_label=LLG_EDGE_LABELS["neutral"])
 
-        def deal_with_action_prec_or_eff(predicates, edge_type):
+        def deal_with_action_prec_or_eff(predicates, edge_label):
             for z, predicate in enumerate(predicates):
                 pred = predicate.predicate
-                aux_node = (pred, f"{edge_type}-aux-{z}")  # aux node for duplicate preds
+                aux_node = (pred, f"{edge_label}-aux-{z}")  # aux node for duplicate preds
                 G.add_node(aux_node, x=self._zero_node())
 
                 assert pred in G.nodes()
-                G.add_edge(u_of_edge=pred, v_of_edge=aux_node, edge_type=LLG_EDGE_TYPES[edge_type])
+                G.add_edge(u_of_edge=pred, v_of_edge=aux_node, edge_label=LLG_EDGE_LABELS[edge_label])
 
                 if len(predicate.args) > 0:
                     for j, arg in enumerate(predicate.args):
-                        prec_arg_node = (arg, f"{edge_type}-aux-{z}-var-{j}")  # aux var
+                        prec_arg_node = (arg, f"{edge_label}-aux-{z}-var-{j}")  # aux var
                         G.add_node(prec_arg_node, x=self._if_feature(idx=j))
-                        G.add_edge(u_of_edge=aux_node, v_of_edge=prec_arg_node, edge_type=LLG_EDGE_TYPES[edge_type])
+                        G.add_edge(u_of_edge=aux_node, v_of_edge=prec_arg_node, edge_label=LLG_EDGE_LABELS[edge_label])
                         if arg in action_args:
                             action_arg_node = action_args[arg]
-                            G.add_edge(u_of_edge=prec_arg_node, v_of_edge=action_arg_node, edge_type=LLG_EDGE_TYPES[edge_type])
+                            G.add_edge(u_of_edge=prec_arg_node, v_of_edge=action_arg_node, edge_label=LLG_EDGE_LABELS[edge_label])
                 else:  # unitary predicate so connect directly to action
-                    G.add_edge(u_of_edge=aux_node, v_of_edge=action.name, edge_type=LLG_EDGE_TYPES[edge_type])
+                    G.add_edge(u_of_edge=aux_node, v_of_edge=action.name, edge_label=LLG_EDGE_LABELS[edge_label])
             return
 
         pos_pres = [p for p in action.precondition.parts if type(p)==Atom]
@@ -193,7 +198,7 @@ def str_to_state(self, s) -> List[Tuple[str, List[str]]]:
                 state.append((toks[0], ()))
         return state
 
-    def get_state_enc(self, state: List[Tuple[str, List[str]]]) -> Tuple[Tensor, Tensor]:
+    def state_to_tensor(self, state: List[Tuple[str, List[str]]]) -> TGraph:
         """ States are represented as a list of (pred, [args]) """
         x = self.x.clone()
         edge_indices = self.edge_indices.copy()
@@ -226,7 +231,7 @@ def state_to_tensor(self, state: List[Tuple[str, List[str]]]) -> TGraph:
             # connect to predicates and objects
             for k, arg in enumerate(args):
                 true_var_node_i = i
-                x[i][-VAR_FEAT_SIZE:] = self._pe[k]
+                x[i][-VAR_FEAT_SIZE:] = self._if[k]
                 i += 1
 
                 # connect variable to predicate
@@ -237,8 +242,48 @@ def state_to_tensor(self, state: List[Tuple[str, List[str]]]) -> TGraph:
                 append_edge_index.append((true_var_node_i, self._node_to_i[arg]))
                 append_edge_index.append((self._node_to_i[arg], true_var_node_i))
 
-        edge_indices[LLG_EDGE_TYPES["ground"]] = torch.hstack((edge_indices[LLG_EDGE_TYPES["ground"]],
+        edge_indices[LLG_EDGE_LABELS["ground"]] = torch.hstack((edge_indices[LLG_EDGE_LABELS["ground"]],
                                                                torch.tensor(append_edge_index).T)).long()
 
         return x, edge_indices
+
+    def state_to_cgraph(self, state: List[Tuple[str, List[str]]]) -> CGraph:
+        """ States are represented as a list of (pred, [args]) """
+        c_graph = self.c_graph.copy()
+
+        for fact in state:
+            pred = fact[0]
+            args = fact[1]
+
+            node = (pred, tuple(args))
+
+            # activated proposition overlaps with a goal Atom or NegatedAtom
+            if node in self._pos_goal_nodes:
+                c_graph.nodes[node]['colour'] = c_graph.nodes[node]['colour']+ACTIVATED_POS_GOAL_COLOUR_SUFFIX
+                continue
+            elif node in self._neg_goal_nodes:
+                c_graph.nodes[node]['colour'] = c_graph.nodes[node]['colour']+ACTIVATED_NEG_GOAL_COLOUR_SUFFIX
+                continue
+
+            # else add node and corresponding edges to graph
+            c_graph.add_node(node, colour=ACTIVATED_COLOUR)
+
+            # connect fact to predicate
+            c_graph.add_edge(u_of_edge=node, v_of_edge=pred, edge_label=LLG_EDGE_LABELS["ground"])
+            c_graph.add_edge(v_of_edge=node, u_of_edge=pred, edge_label=LLG_EDGE_LABELS["ground"])
+
+            # connect to predicates and objects
+            for k, arg in enumerate(args):
+                arg_node = (node, f"true-var-{k}")
+                c_graph.add_node(arg_node, colour=str(k)+IF_COLOUR_SUFFIX)
+
+                # connect variable to predicate
+                c_graph.add_edge(u_of_edge=node, v_of_edge=arg_node, edge_label=LLG_EDGE_LABELS["ground"])
+                c_graph.add_edge(v_of_edge=node, u_of_edge=arg_node, edge_label=LLG_EDGE_LABELS["ground"])
+
+                # connect variable to object
+                c_graph.add_edge(u_of_edge=arg_node, v_of_edge=arg, edge_label=LLG_EDGE_LABELS["ground"])
+                c_graph.add_edge(v_of_edge=arg_node, u_of_edge=arg, edge_label=LLG_EDGE_LABELS["ground"])
+
+        return c_graph
\ No newline at end of file
self._proposition_to_str(proposition) assert p_node in G.nodes, f"{p_node} not in nodes" assert a_node in G.nodes, f"{a_node} not in nodes" - G.add_edge(u_of_edge=p_node, v_of_edge=a_node, edge_type=SLG_EDGE_TYPES.DEL_EDGE.value) + G.add_edge(u_of_edge=p_node, v_of_edge=a_node, edge_label=SLG_EDGE_LABELS.DEL_EDGE.value) # map node name to index self._node_to_i = {} @@ -127,7 +129,7 @@ def _compute_graph_representation(self) -> None: return - def get_state_enc(self, state: State) -> Tuple[Tensor, Tensor]: + def state_to_tensor(self, state: State) -> Tuple[Tensor, Tensor]: x = self.x.clone() # not time nor memory efficient, but no other way in Python for p in state: @@ -135,3 +137,18 @@ def get_state_enc(self, state: State) -> Tuple[Tensor, Tensor]: x[self._node_to_i[p]][SLG_FEATURES.STATE.value] = 1 return x, self.edge_indices + + def state_to_cgraph(self, state: State) -> CGraph: + """ States are represented as a list of proposition strings """ + c_graph = self.c_graph.copy() + + for p in state: + + # activated proposition overlaps with a goal Atom or NegatedAtom + if p in self._pos_goal_nodes: + c_graph.nodes[p]['colour'] = c_graph.nodes[p]['colour']+ACTIVATED_POS_GOAL_COLOUR_SUFFIX + elif p in self._neg_goal_nodes: + c_graph.nodes[p]['colour'] = c_graph.nodes[p]['colour']+ACTIVATED_NEG_GOAL_COLOUR_SUFFIX + + return c_graph + \ No newline at end of file diff --git a/learner/run_gnn.py b/learner/run_gnn.py index 31330494..8882002d 100644 --- a/learner/run_gnn.py +++ b/learner/run_gnn.py @@ -32,4 +32,5 @@ seed=0, ) + print(cmd) os.system(cmd) diff --git a/learner/scripts/generate_all_graphs.sh b/learner/scripts/generate_all_graphs.sh deleted file mode 100644 index 1ad626ef..00000000 --- a/learner/scripts/generate_all_graphs.sh +++ /dev/null @@ -1,5 +0,0 @@ -for rep in ldg-el fdg-el sdg-el gdg-el -do - echo "python3 scripts/generate_graphs.py $rep --regenerate" - python3 scripts/generate_graphs.py $rep --regenerate -done diff --git a/learner/scripts/.gitignore b/learner/scripts_gnn/.gitignore similarity index 100% rename from learner/scripts/.gitignore rename to learner/scripts_gnn/.gitignore diff --git a/learner/scripts/cluster1_job_3090 b/learner/scripts_gnn/cluster1_job_3090 similarity index 100% rename from learner/scripts/cluster1_job_3090 rename to learner/scripts_gnn/cluster1_job_3090 diff --git a/learner/scripts/cluster1_job_a6000 b/learner/scripts_gnn/cluster1_job_a6000 similarity index 100% rename from learner/scripts/cluster1_job_a6000 rename to learner/scripts_gnn/cluster1_job_a6000 diff --git a/learner/scripts/cluster1_job_any b/learner/scripts_gnn/cluster1_job_any similarity index 100% rename from learner/scripts/cluster1_job_any rename to learner/scripts_gnn/cluster1_job_any diff --git a/learner/scripts/cluster1_job_planopt b/learner/scripts_gnn/cluster1_job_planopt similarity index 100% rename from learner/scripts/cluster1_job_planopt rename to learner/scripts_gnn/cluster1_job_planopt diff --git a/learner/scripts/collect_cluster1_logs.sh b/learner/scripts_gnn/collect_cluster1_logs.sh similarity index 100% rename from learner/scripts/collect_cluster1_logs.sh rename to learner/scripts_gnn/collect_cluster1_logs.sh diff --git a/learner/scripts/predict_dd_and_di.py b/learner/scripts_gnn/predict_dd_and_di.py similarity index 100% rename from learner/scripts/predict_dd_and_di.py rename to learner/scripts_gnn/predict_dd_and_di.py diff --git a/learner/scripts/submit_dd_train_only.sh b/learner/scripts_gnn/submit_dd_train_only.sh similarity index 100% rename from
learner/scripts/submit_dd_train_only.sh rename to learner/scripts_gnn/submit_dd_train_only.sh diff --git a/learner/scripts/submit_dd_train_validate_test.sh b/learner/scripts_gnn/submit_dd_train_validate_test.sh similarity index 100% rename from learner/scripts/submit_dd_train_validate_test.sh rename to learner/scripts_gnn/submit_dd_train_validate_test.sh diff --git a/learner/scripts/submit_di_train_only.sh b/learner/scripts_gnn/submit_di_train_only.sh similarity index 100% rename from learner/scripts/submit_di_train_only.sh rename to learner/scripts_gnn/submit_di_train_only.sh diff --git a/learner/scripts/submit_di_train_validate_test.sh b/learner/scripts_gnn/submit_di_train_validate_test.sh similarity index 100% rename from learner/scripts/submit_di_train_validate_test.sh rename to learner/scripts_gnn/submit_di_train_validate_test.sh diff --git a/learner/scripts/submit_predict.sh b/learner/scripts_gnn/submit_predict.sh similarity index 100% rename from learner/scripts/submit_predict.sh rename to learner/scripts_gnn/submit_predict.sh diff --git a/learner/scripts/train_validate_test_dd.py b/learner/scripts_gnn/train_validate_test_dd.py similarity index 100% rename from learner/scripts/train_validate_test_dd.py rename to learner/scripts_gnn/train_validate_test_dd.py diff --git a/learner/scripts/train_validate_test_di.py b/learner/scripts_gnn/train_validate_test_di.py similarity index 100% rename from learner/scripts/train_validate_test_di.py rename to learner/scripts_gnn/train_validate_test_di.py diff --git a/learner/scripts_kernel/cross_validate_all.sh b/learner/scripts_kernel/cross_validate_all.sh new file mode 100644 index 00000000..a628aa2c --- /dev/null +++ b/learner/scripts_kernel/cross_validate_all.sh @@ -0,0 +1,18 @@ +LOG_DIR=logs/train_kernel + +mkdir -p $LOG_DIR + +for l in 0 1 2 3 4 +do + for k in wl + do + for r in llg slg dlg glg + do + for d in gripper spanner visitall visitsome blocks ferry sokoban n-puzzle + do + echo $r $k $l $d + python3 train_kernel.py -k $k -l $l -r $r -d $d --visualise --cross-validation > $LOG_DIR/${r}_${d}_${k}_${l}.log + done + done + done +done \ No newline at end of file diff --git a/learner/scripts_kernel/train_all.sh b/learner/scripts_kernel/train_all.sh new file mode 100644 index 00000000..3cade11c --- /dev/null +++ b/learner/scripts_kernel/train_all.sh @@ -0,0 +1,18 @@ +LOG_DIR=logs/train_kernel + +mkdir -p $LOG_DIR + +for l in 0 1 2 3 4 +do + for k in wl + do + for r in llg slg dlg glg + do + for d in gripper spanner visitall visitsome blocks ferry sokoban n-puzzle + do + echo $r $k $l $d + python3 train_kernel.py -k $k -l $l -r $r -d $d --save-file ${r}_${d}_${k}_${l} > $LOG_DIR/${r}_${d}_${k}_${l}.log + done + done + done +done \ No newline at end of file diff --git a/learner/test_gnn.sh b/learner/test_gnn.sh index c773d86d..a0a53ffb 100644 --- a/learner/test_gnn.sh +++ b/learner/test_gnn.sh @@ -1 +1,2 @@ -singularity exec --nv ../gpu.sif python3 run_gnn.py ../benchmarks/goose/gripper/domain.pddl ../benchmarks/goose/gripper/test/gripper-n20.pddl -m saved_models/dd_llg_gripper.dt -r llg \ No newline at end of file +# singularity exec --nv ../gpu.sif python3 run_gnn.py ../benchmarks/goose/gripper/domain.pddl ../benchmarks/goose/gripper/test/gripper-n20.pddl -m saved_models/dd_llg_gripper.dt -r llg +singularity exec --nv ../gpu.sif python3 run_gnn.py ../benchmarks/goose/gripper/domain.pddl ../benchmarks/goose/gripper/test/gripper-n20.pddl -m saved_models/dd_slg_gripper.dt -r slg \ No newline at end of file diff --git a/learner/train_gnn.py 
b/learner/train_gnn.py index e5cf904d..24a6a57b 100755 --- a/learner/train_gnn.py +++ b/learner/train_gnn.py @@ -3,25 +3,25 @@ import time import torch import argparse -import models +import gnns import representation -from models import * +from gnns import * from tqdm.auto import tqdm, trange from util.stats import * from util.save_load import * from util import train, evaluate -from dataset.dataset import get_loaders_from_args +from dataset.dataset import get_loaders_from_args_gnn def create_parser(): parser = argparse.ArgumentParser() parser.add_argument('--device', type=int, default=0) - parser.add_argument('-d', '--domain', default="goose-pretraining") + parser.add_argument('-d', '--domain', default="goose-di") parser.add_argument('-t', '--task', default='h', choices=["h", "a"], help="predict value or action (currently only h is supported)") # model params - parser.add_argument('-m', '--model', type=str, required=True, choices=models.GNNS) + parser.add_argument('-m', '--model', type=str, required=True, choices=gnns.GNNS) parser.add_argument('-L', '--nlayers', type=int, default=16) parser.add_argument('-H', '--nhid', type=int, default=64) parser.add_argument('--share-layers', action='store_true') @@ -72,7 +72,7 @@ def check_config(args): device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu') # init model - train_loader, val_loader = get_loaders_from_args(args) + train_loader, val_loader = get_loaders_from_args_gnn(args) args.n_edge_labels = representation.REPRESENTATIONS[args.rep].n_edge_labels args.in_feat = train_loader.dataset[0].x.shape[1] model_params = arg_to_params(args) diff --git a/learner/train_kernel.py b/learner/train_kernel.py new file mode 100755 index 00000000..3ca03f41 --- /dev/null +++ b/learner/train_kernel.py @@ -0,0 +1,177 @@ +""" Main training pipeline script. 
""" + +import os +import time +import argparse +import representation +import kernels +import numpy as np +from dataset.dataset import get_dataset_from_args_kernels +from util.save_load import print_arguments, save_sklearn_model +from util.metrics import f1_macro +from util.visualise import get_confusion_matrix +from sklearn.svm import LinearSVR, SVR +from sklearn.model_selection import cross_validate +from sklearn.metrics import make_scorer, mean_squared_error + +import warnings +warnings.filterwarnings('ignore') + + +_MODELS = [ + "linear-svr", + "svr", +] + +_CV_FOLDS = 5 +_MAX_MODEL_ITER = 10000 +_PLOT_DIR = "plots" +_SCORING = { + "mse": make_scorer(mean_squared_error), + "f1_macro": make_scorer(f1_macro) +} + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument('-r', '--rep', type=str, required=True, choices=representation.REPRESENTATIONS, + help="graph representation to use") + + parser.add_argument('-k', '--kernel', type=str, required=True, choices=kernels.KERNELS, + help="graph representation to use") + parser.add_argument('-l', '--iterations', type=int, default=5, + help="number of iterations for kernel algorithms") + parser.add_argument('--final-only', dest="all_colours", action="store_false", + help="collects colours from only final iteration of WL kernels") + + parser.add_argument('-m', '--model', type=str, default="linear-svr", choices=_MODELS, + help="ML model") + parser.add_argument('-C', type=float, default=1, + help="regularisation parameter of SVR; strength is inversely proportional to C") + parser.add_argument('-e', type=float, default=0.1, + help="epsilon parameter in epsilon insensitive loss function of SVR") + + parser.add_argument('-d', '--domain', type=str, default="goose-di", + help="domain to train on; defaults to goose-di which is di training") + + parser.add_argument('-s', '--seed', type=int, default=0, + help="random seed") + + parser.add_argument('--cross-validation', action='store_true', + help="performs cross validation scoring; otherwise train on whole dataset") + parser.add_argument('--save-file', type=str, default=None, + help="save file of model weights when not using --cross-validation") + parser.add_argument('--visualise', action='store_true', + help="visualise train and test predictions; only used with --cross-validation") + parser.add_argument('--small-train', action="store_true", + help="use small train set, useful for debugging") + + return parser.parse_args() + +def perform_training(X, y, model, args): + print(f"Training on entire {args.domain} for {model_name}...") + t = time.time() + model.fit(X, y) + print(f"Model training completed in {time.time()-t:.2f}s") + for metric in _SCORING: + print(f"train_{metric}: {_SCORING[metric](model, X, y):.2f}") + save_sklearn_model(model, args) + return + +def perform_cross_validation(X, y, model, args): + print(f"Performing {_CV_FOLDS}-fold cross validation on {model_name}...") + t = time.time() + scores = cross_validate( + model, X, y, + cv=_CV_FOLDS, scoring=_SCORING, return_train_score=True, n_jobs=-1, + return_estimator=args.visualise, return_indices=args.visualise, + ) + print(f"CV completed in {time.time() - t:.2f}s") + + for metric in _SCORING: + train_key = f"train_{metric}" + test_key = f"test_{metric}" + print(f"train_{metric}: {scores[train_key].mean():.2f} ± {scores[train_key].std():.2f}") + print(f"test_{metric}: {scores[test_key].mean():.2f} ± {scores[test_key].std():.2f}") + + if args.visualise: + """ Visualise predictions and save to file + Performs some 
redundant computations + """ + + if model_name == "svr": # kernel matrix case + raise NotImplementedError + + print("Saving visualisation...") + train_trues = [] + train_preds = [] + test_trues = [] + test_preds = [] + + for i in range(_CV_FOLDS): + estimator = scores["estimator"][i] + train_indices = scores["indices"]["train"][i] + test_indices = scores["indices"]["test"][i] + X_train = X[train_indices] + X_test = X[test_indices] + y_train = y[train_indices] + y_test = y[test_indices] + train_pred = estimator.predict(X_train) + test_pred = estimator.predict(X_test) + train_trues.append(y_train) + train_preds.append(train_pred) + test_trues.append(y_test) + test_preds.append(test_pred) + + y_true_train = np.concatenate(train_trues) + y_pred_train = np.concatenate(train_preds) + y_true_test = np.concatenate(test_trues) + y_pred_test = np.concatenate(test_preds) + + plt = get_confusion_matrix(y_true_train, y_pred_train, y_true_test, y_pred_test) + + os.makedirs(_PLOT_DIR, exist_ok=True) + file_name = _PLOT_DIR + "/" + "_".join([args.domain, args.rep, args.kernel, str(args.iterations)]) + ".pdf" + plt.savefig(file_name, bbox_inches="tight") + print(f"Visualisation saved at {file_name}") + return + + +if __name__ == "__main__": + args = parse_args() + print_arguments(args) + + np.random.seed(args.seed) + + print(f"Initialising {args.kernel}...") + graphs, y = get_dataset_from_args_kernels(args) + kernel = kernels.KERNELS[args.kernel]( + iterations=args.iterations, + all_colours=args.all_colours, + ) + kernel.read_train_data(graphs) + + print(f"Setting up training data and initialising model...") + model_name = args.model + t = time.time() + + kwargs = { + "epsilon": args.e, + "C": args.C, + "max_iter": _MAX_MODEL_ITER, + } + if model_name == "linear-svr": + model = LinearSVR(dual="auto", **kwargs) + X = kernel.get_x(graphs) + elif model_name == "svr": + model = SVR(kernel="precomputed", **kwargs) + X = kernel.get_k(graphs) + else: + raise NotImplementedError + print(f"Set up training data in {time.time()-t:.2f}s") + + if args.cross_validation: + perform_cross_validation(X, y, model, args) + else: + perform_training(X, y, model, args) \ No newline at end of file diff --git a/learner/util/metrics.py b/learner/util/metrics.py index 20116b82..bbbcc374 100644 --- a/learner/util/metrics.py +++ b/learner/util/metrics.py @@ -8,6 +8,11 @@ """ Module containing metrics for inference only. """ +def f1_macro(y_true: np.array, y_pred: np.array) -> float: + y_true = np.rint(y_true).astype(int) + y_pred = np.rint(y_pred).astype(int) + return f1_score(y_true, y_pred, average='macro') + @torch.no_grad() def eval_accuracy(y_pred: Tensor, y_true: Tensor): try: diff --git a/learner/util/save_load.py b/learner/util/save_load.py index 85c4ebf7..0ac282f6 100644 --- a/learner/util/save_load.py +++ b/learner/util/save_load.py @@ -1,111 +1,133 @@ +""" Module for dealing with model saving and loading. """ import os import torch +import joblib import datetime import representation from argparse import Namespace as Args from typing import Tuple -from models.base_gnn import BasePredictor as GNN -from models import * +from gnns.base_gnn import BasePredictor as GNN +from gnns import * -""" Module for dealing with model saving and loading. 
""" + +_TRAINED_MODELS_SAVE_DIR = "trained_models" +os.makedirs(_TRAINED_MODELS_SAVE_DIR, exist_ok=True) def arg_to_params(args, in_feat=4, out_feat=1): - model_name = args.model - nlayers = args.nlayers - nhid = args.nhid - in_feat = args.in_feat - n_edge_labels = args.n_edge_labels - share_layers = args.share_layers - task = args.task - pool = args.pool - aggr = args.aggr - vn = args.vn - rep = args.rep - model_params = { - 'model_name': model_name, - 'in_feat': in_feat, - 'out_feat': out_feat, - 'nlayers': nlayers, - 'share_layers': share_layers, - 'n_edge_labels': n_edge_labels, - 'nhid': nhid, - 'aggr': aggr, - 'pool': pool, - 'task': task, - 'rep': rep, - 'vn': vn, - } - return model_params + model_name = args.model + nlayers = args.nlayers + nhid = args.nhid + in_feat = args.in_feat + n_edge_labels = args.n_edge_labels + share_layers = args.share_layers + task = args.task + pool = args.pool + aggr = args.aggr + vn = args.vn + rep = args.rep + model_params = { + 'model_name': model_name, + 'in_feat': in_feat, + 'out_feat': out_feat, + 'nlayers': nlayers, + 'share_layers': share_layers, + 'n_edge_labels': n_edge_labels, + 'nhid': nhid, + 'aggr': aggr, + 'pool': pool, + 'task': task, + 'rep': rep, + 'vn': vn, + } + return model_params def print_arguments(args, ignore_params=set()): - if hasattr(args, 'pretrained') and args.pretrained is not None: - return - print("Parsed arguments:") - for k, v in vars(args).items(): - if k in ignore_params.union({"device", "optimal", "save_model", "save_file", "no_tqdm", "tqdm", "fast_train"}): - continue - print('{0:20} {1}'.format(k, v)) + if hasattr(args, 'pretrained') and args.pretrained is not None: + return + print("Parsed arguments:") + for k, v in vars(args).items(): + if k in ignore_params.union({"device", "optimal", "save_model", "save_file", "no_tqdm", "tqdm", "fast_train"}): + continue + print('{0:20} {1}'.format(k, v)) def save_model_from_dict(model_dict, args): - if not hasattr(args, "save_file") or args.save_file is None: - return - print("Saving model...") - save_dir = 'trained_models' - os.makedirs(f"{save_dir}/", exist_ok=True) - model_file_name = args.save_file.replace(".dt", "") - path = f'{save_dir}/{model_file_name}.dt' - torch.save((model_dict, args), path) - print("Model saved!") - print("Model parameter file:") - print(model_file_name) + if not hasattr(args, "save_file") or args.save_file is None: return + print("Saving model...") + model_file_name = args.save_file.replace(".dt", "") + path = f'{_TRAINED_MODELS_SAVE_DIR}/{model_file_name}.dt' + torch.save((model_dict, args), path) + print("Model saved!") + print("Model parameter file:") + print(model_file_name) + return def save_model(model, args): - save_model_from_dict(model.model.state_dict(), args) + save_model_from_dict(model.model.state_dict(), args) + return + + +def save_sklearn_model(model, args): + if not hasattr(args, "save_file") or args.save_file is None: return + print("Saving model...") + model_file_name = args.save_file.replace(".joblib", "") + path = f'{_TRAINED_MODELS_SAVE_DIR}/{model_file_name}.joblib' + joblib.dump((model, args), path) + print("Model saved!") + print("Model parameter file:") + print(model_file_name) + return + + +def load_sklearn_model(path, ignore_subdir=False): + if not ignore_subdir and _TRAINED_MODELS_SAVE_DIR not in path: + path = _TRAINED_MODELS_SAVE_DIR + "/" + path + model, args = joblib.load(path) + return model, args def load_model(path, print_args=False, jit=False, ignore_subdir=False) -> Tuple[GNN, Args]: - print("Loading 
model...") - assert ".pt" not in path, f"Found .pt in path {path}" - if ".dt" not in path: - path = path+".dt" - if not ignore_subdir and "trained_models" not in path: - path = "trained_models/" + path - try: - if torch.cuda.is_available(): - model_state_dict, args = torch.load(path) - else: - model_state_dict, args = torch.load(path, map_location=torch.device('cpu')) - except: - print(f"Model not found at {path}") - exit(-1) - # update legacy naming - if "dg-el" in args.rep: - args.rep = args.rep.replace("dg-el", "lg") - model = GNNS[args.model](params=arg_to_params(args), jit=jit) - model.load_state_dict_into_gnn(model_state_dict) - print("Model loaded!") - if print_args: - print_arguments(args) - model.eval() - return model, args + print("Loading model...") + assert ".pt" not in path, f"Found .pt in path {path}" + if ".dt" not in path: + path = path+".dt" + if not ignore_subdir and _TRAINED_MODELS_SAVE_DIR not in path: + path = _TRAINED_MODELS_SAVE_DIR + "/" + path + try: + if torch.cuda.is_available(): + model_state_dict, args = torch.load(path) + else: + model_state_dict, args = torch.load(path, map_location=torch.device('cpu')) + except: + print(f"Model not found at {path}") + exit(-1) + # update legacy naming + if "dg-el" in args.rep: + args.rep = args.rep.replace("dg-el", "lg") + model = GNNS[args.model](params=arg_to_params(args), jit=jit) + model.load_state_dict_into_gnn(model_state_dict) + print("Model loaded!") + if print_args: + print_arguments(args) + model.eval() + return model, args def load_model_and_setup_gnn(path, domain_file, problem_file): - model, args = load_model(path, ignore_subdir=True) - device = "cuda" if torch.cuda.is_available() else "cpu" - model = model.to(device) - model.batch_search(True) - model.update_representation(domain_pddl=domain_file, - problem_pddl=problem_file, - args=args, - device=device) - model.set_zero_grad() - model.eval() - return model - + model, args = load_model(path, ignore_subdir=True) + device = "cuda" if torch.cuda.is_available() else "cpu" + model = model.to(device) + model.batch_search(True) + model.update_representation(domain_pddl=domain_file, + problem_pddl=problem_file, + args=args, + device=device) + model.set_zero_grad() + model.eval() + return model + \ No newline at end of file diff --git a/learner/util/search.py b/learner/util/search.py index 118f5324..89582f01 100644 --- a/learner/util/search.py +++ b/learner/util/search.py @@ -64,25 +64,10 @@ def fd_cmd(rep, df, pf, m, search, seed, timeout=TIMEOUT): else: raise NotImplementedError - # A hack given that FD FeaturePlugin cannot parse strings - # 0: slg, 1: flg, 2: dlg, 3: llg - assert rep in REPRESENTATIONS - config_file = rep - config = { - "slg":0, - "flg":1, - "dlg":2, - "llg":3, - }[rep] - description = f"fd_{pf.replace('.pddl','').replace('/','-')}_{search}_{os.path.basename(m).replace('.dt', '')}" sas_file = f"sas_files/{description}.sas_file" plan_file = f"plans/{description}.plan" - with open(config_file, 'w') as f: - f.write(m+'\n') - f.write(df+'\n') - f.write(pf+'\n') - f.close() - cmd = f'./../downward/fast-downward.py --search-time-limit {timeout} --sas-file {sas_file} --plan-file {plan_file} {df} {pf} --search "{search}([goose(graph={config})])"' + cmd = f"./../downward/fast-downward.py --search-time-limit {timeout} --sas-file {sas_file} --plan-file {plan_file} "+\ + f"{df} {pf} --search '{search}([goose(model_path=\"{m}\", domain_file=\"{df}\", instance_file=\"{pf}\")])'" cmd = f"export GOOSE={os.getcwd()} && {cmd}" return cmd, sas_file diff --git 
a/learner/util/stats.py b/learner/util/stats.py index ca3d7423..065f0130 100644 --- a/learner/util/stats.py +++ b/learner/util/stats.py @@ -50,25 +50,7 @@ def print_quartiles(desc: str, data: np.array, floats: bool = False): print(f"{desc:<20} {q1:>10} {q2:>10} {q3:>10} {min(data):>10} {max(data):>10}") -def get_y_stats(dataset): - ys = [] - for data in dataset: - y = round(data.y) - ys.append(y) - - ys = np.array(ys) - # os.makedirs("plots/", exist_ok=True) - # plt.hist(ys, bins=round(np.max(ys) + 1), - # range=(0, round(np.max(ys) + 1))) - # plt.xlim(left=0) - # # plt.title('y distribution') - # plt.savefig('plots/y_distribution.pdf', bbox_inches="tight") - # plt.clf() - - return ys - - -def get_stats(dataset, iteration_stats=False, desc=""): +def get_stats(dataset, desc=""): if len(dataset) == 0: return cnt = {} @@ -76,248 +58,37 @@ def get_stats(dataset, iteration_stats=False, desc=""): graph_nodes = [] graph_edges = [] graph_dense = [] - iterations = [] + ys = [] for data in dataset: - y = data.y + if type(dataset[0]) == tuple: # CGraphs + graph, y = data + n_nodes = len(graph.nodes) + n_edges = len(graph.edges) + else: # TGraphs + y = data.y + n_nodes = data.x.shape[0] if data.x is not None else 0 + try: + n_edges = data.edge_index.shape[1] + except: + n_edges = sum(e.shape[1] for e in data.edge_index) + density = graph_density(n_nodes, n_edges, directed=True) + if y not in cnt: cnt[y] = 0 cnt[y] += 1 max_cost = max(max_cost, round(y)) - - if iteration_stats: - iterations.append(data.iterations) - - if data.x is None: - n_nodes = 0 - else: - n_nodes = data.x.shape[0] - try: - n_edges = data.edge_index.shape[1] - except: - # print(data.edge_index) - # for a in data.edge_index: - # print(a) - n_edges = sum(e.shape[1] for e in data.edge_index) - density = graph_density(n_nodes, n_edges, directed=True) graph_nodes.append(n_nodes) graph_edges.append(n_edges) graph_dense.append(density) - - # Cost/y distribution - # print('Cost distribution') - ys = get_y_stats(dataset) + ys.append(y) # Statistics print_quartile_desc(desc) - if iteration_stats: - print_quartiles("iterations:", iterations) print_quartiles("costs:", ys) print_quartiles("n_nodes:", graph_nodes) print_quartiles("n_edges:", graph_edges) print_quartiles("density:", graph_dense, floats=True) return - - -def view_confusion_matrix(plt_title, y_pred, y_true, view_cm, alt_save="", cutoff=-1, fontsize=None, removeaxeslabel=False): - if fontsize is not None: - plt.rcParams.update({'font.size': fontsize}) - y_pred = [round(i) for i in y_pred] - y_true = [round(i) for i in y_true] - # min_true = min(y_true) - # y_pred = y_pred + list(range(min_true)) - # y_true = y_true + list(range(min_true)) - fig, ax = plt.subplots(figsize=(10, 10)) - if cutoff == -1: - cutoff = max(y_true)+1 - cm = confusion_matrix(y_true, y_pred, normalize="true", labels=list(range(0,cutoff))) - display_labels = None - max_y = cm.shape[0] - if max_y >= 50: - display_labels = [] - for y in range(max_y): - if y % 10 == 0: - display_labels.append(y) - else: - display_labels.append("") - disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=display_labels) - disp.plot(include_values=False, xticks_rotation="vertical", ax=ax, colorbar=False, cmap=plt.cm.Blues) - disp.im_.set_clim(0, 1) - plt_title = str(plt_title) - # plt.title(plt_title) - plt_title = ' '.join(plt_title.split()) - plt_title = plt_title.replace(" ", "_") - plt.axis("off") - if removeaxeslabel: - plt.gca().xaxis.label.set_visible(False) - plt.gca().yaxis.label.set_visible(False) - if alt_save 
!= "": - # alt_save = alt_save.replace(".pdf", "") - # alt_save = alt_save.replace(".png", "") - plt.savefig(f"{alt_save}", bbox_inches="tight") - else: - plt_title = plt_title.replace(".pdf", "") - plt_title = plt_title.replace(".png", "") - plt.savefig(f"plots/{plt_title}.pdf") - if view_cm: - print(f"Showing {plt_title}") - plt.show() - plt.clf() - return - - -@torch.no_grad() -def visualise_loader_stats(model, device, loader, title): - # visualise_train_stats so disgusting so just make another one here - model.eval() - y_true = torch.tensor([]) - y_pred = torch.tensor([]) - for data in tqdm(loader): - data = data.to(device) - y = data.y - out = model.forward(data) - - y_pred = torch.cat((y_pred, out.detach().cpu())) - y_true = torch.cat((y_true, y.detach().cpu())) - - loss = torch.nn.MSELoss()(y_pred, y_true) - macro_f1, micro_f1 = eval_f1_score(y_pred=y_pred, y_true=y_true) - admis = eval_admissibility(y_pred=y_pred, y_true=y_true) - print(f"size: {len(y_true)}") - print(f"loss: {loss:.2f}") - print(f"f1: {macro_f1:.1f}") - print(f"admissibility: {admis:.1f}") - title = f"{title} f1={macro_f1:.1f} loss={loss:.2f}" - view_confusion_matrix(title, y_pred.tolist(), y_true.tolist(), view_cm=True) - return - - -@torch.no_grad() -def visualise_train_stats(model, device, train_loader, val_loader=None, test_loader=None, max_cost=20, print_stats=True, - classify=False, view_cm=False, cm_train="cm_train", cm_val="cm_val", cm_test="cm_test"): - model = model.to(device) - model.eval() - - def get_stats_from_loader(loader): - preds = [] - true = [] - errors = [[] for _ in range(max_cost+1)] - for batch in tqdm(loader): - batch = batch.to(device) - y = batch.y - out = model.forward(batch) - if classify: - out = torch.argmax(out, dim=1) - else: - out = torch.maximum(out, torch.zeros_like(out)) # so h is nonzero - batch_errors = (y - out) / y - for i in range(len(y)): - e = batch_errors[i].detach().cpu().item() - c = y[i].detach().cpu().item() - o = out[i].detach().cpu().item() - preds.append(round(o)) - true.append(c) - errors[0].append(e) - if c <= max_cost: - errors[round(c)].append(c - o) - errors[0] = np.array(errors[0]) - errors[0][np.isnan(errors[0])] = 0 - preds = np.array(preds) - true = np.array(true) - return preds, true, errors - - print("Collecting stats...") - - # print("Prediction value set", np.unique(train_preds, return_counts=True)) - os.makedirs("plots", exist_ok=True) - for fname in ["error_prop", "preds_train", "error_train", "preds_val", "error_val", "preds_test", "error_test"]: - try: - os.remove(f"plots/{fname}.png") - except: - pass - - boxes = [] - ticks = [] - - if train_loader is not None: - train_preds, train_true, train_errors = get_stats_from_loader(train_loader) - view_confusion_matrix(plt_title=cm_train, y_true=train_true, y_pred=train_preds, view_cm=view_cm) - # boxes.append(train_errors[0]) - # ticks.append((len(boxes), 'train')) - # plt.hist(train_preds, bins=round(np.max(train_preds) + 1), - # range=(0, round(np.max(train_preds) + 1))) - # plt.title('Train prediction distribution') - # plt.savefig('plots/preds_train', dpi=480) - # plt.clf() - # - # plt.boxplot([train_errors[i] for i in range(1, max_cost + 1)]) - # plt.title('Train error differences over states away from target') - # plt.ylim((-4, 4)) - # plt.tight_layout() - # plt.savefig('plots/error_train', dpi=480) - # plt.clf() - if val_loader is not None: - val_preds, val_true, val_errors = get_stats_from_loader(val_loader) - view_confusion_matrix(plt_title=cm_val, y_true=val_true, y_pred=val_preds, 
view_cm=view_cm) - # boxes.append(val_errors[0]) - # ticks.append((len(boxes), 'val')) - # plt.hist(val_preds, bins=round(np.max(val_preds) + 1), - # range=(0, round(np.max(val_preds) + 1))) - # plt.title('Validation prediction distribution') - # plt.savefig('plots/preds_val', dpi=480) - # plt.clf() - # - # plt.boxplot([val_errors[i] for i in range(1, max_cost + 1)]) - # plt.title('Val error differences over states away from target') - # plt.ylim((-4, 4)) - # plt.tight_layout() - # plt.savefig('plots/error_val', dpi=480) - # plt.clf() - if test_loader is not None: - test_preds, test_true, test_errors = get_stats_from_loader(test_loader) - view_confusion_matrix(plt_title=cm_test, y_true=test_true, y_pred=test_preds, view_cm=view_cm) - # boxes.append(test_errors[0]) - # ticks.append((len(boxes), 'test')) - # plt.hist(test_preds, bins=round(np.max(test_preds) + 1), - # range=(0, round(np.max(test_preds) + 1))) - # plt.title('Test prediction distribution') - # plt.savefig('plots/preds_val', dpi=480) - # plt.clf() - # - # plt.boxplot([test_errors[i] for i in range(1, max_cost + 1)]) - # plt.title('Test error differences over states away from target') - # plt.ylim((-4, 4)) - # plt.tight_layout() - # plt.savefig('plots/error_test', dpi=480) - # plt.clf() - - print("Plotting done!") - - # Statistics - if print_stats: - print("{0:<20} {1:>10} {2:>10} {3:>10} {4:>10} {5:>10}".format(" ", "Q1", "median", "Q3", "min", "max")) - if train_loader is not None: - print_quartiles("train prop_err:", train_errors[0], floats=True) - if val_loader is not None: - print_quartiles("val prop_err:", val_errors[0], floats=True) - if test_loader is not None: - print_quartiles("test prop_err:", test_errors[0], floats=True) - print("% admissible") - if train_loader is not None: - print(f"train: {np.count_nonzero(train_errors[0] > 0) / len(train_errors[0]):.2f}") - if val_loader is not None: - print(f"val: {np.count_nonzero(val_errors[0] > 0) / len(val_errors[0]):.2f}") - if test_loader is not None: - print(f"test: {np.count_nonzero(test_errors[0] > 0) / len(test_errors[0])}:.2f") - - # plt.boxplot(boxes) - # plt.xticks(ticks) - # plt.ylim((-1, 1)) - # plt.title('Proportion errors') - # plt.tight_layout() - # plt.savefig('plots/error_prop', dpi=480) - # plt.clf() - - return diff --git a/learner/util/visualise.py b/learner/util/visualise.py index 58ae8ffe..65329019 100644 --- a/learner/util/visualise.py +++ b/learner/util/visualise.py @@ -1,6 +1,8 @@ import os import sys +from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix + sys.path.append(os.path.join(os.path.dirname(__file__), "..")) import re @@ -320,3 +322,20 @@ def display_solved_test_stats(train_type, L, H, aggr, p): def get_max_of_parameters(df): df = df.drop(columns=["L", "aggr"]).max() return df + +def get_confusion_matrix(y_true_train, y_pred_train, y_true_test, y_pred_test, cutoff=-1): + y_true_train = np.rint(y_true_train).astype(int) + y_pred_train = np.rint(y_pred_train).astype(int) + y_true_test = np.rint(y_true_test).astype(int) + y_pred_test = np.rint(y_pred_test).astype(int) + fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 10)) + if cutoff == -1: + cutoff = max(max(y_true_train), max(y_true_test))+1 + cm_train = confusion_matrix(y_true_train, y_pred_train, normalize="true", labels=list(range(0, cutoff))) + cm_test = confusion_matrix(y_true_test, y_pred_test, normalize="true", labels=list(range(0, cutoff))) + display_labels = [y if y%10==0 else "" for y in range(cutoff)] + for i, cm in enumerate([cm_train, cm_test]): + disp = 
ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=display_labels) + disp.plot(include_values=False, xticks_rotation="vertical", ax=ax[i], colorbar=False) + disp.im_.set_clim(0, 1) + return plt diff --git a/requirements.txt b/requirements.txt index 778299e0..8d9d1b40 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ numpy==1.25.0 pandas==1.5.2 plotly==5.15.0 pytest==7.4.0 -scikit_learn==1.2.0 +scikit_learn==1.3.0 scipy==1.9.3 seaborn==0.12.2 torch==2.0.1
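
For reference, the updated `fd_cmd` in `learner/util/search.py` above now passes the model and PDDL paths to the planner as string options instead of writing them to a temporary config file for the C++ side to re-read. Once its placeholders are substituted, the assembled command looks roughly like the sketch below; the `eager_greedy` search, the 600s limit, and all file paths are illustrative assumptions, not values fixed by this patch:
```
export GOOSE=$(pwd)   # fd_cmd sets this to the learner directory
./../downward/fast-downward.py --search-time-limit 600 \
    --sas-file sas_files/example.sas_file --plan-file plans/example.plan \
    domain.pddl instance.pddl \
    --search 'eager_greedy([goose(model_path="trained_models/model.dt", domain_file="domain.pddl", instance_file="instance.pddl")])'
```
With the string-option support added to the Fast Downward parser earlier in this patch, all three paths flow from Python to the `goose` heuristic in a single place, which is what made the old config-file round trip removable.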