diff --git a/HiPRGen/lmdb_dataset.py b/HiPRGen/lmdb_dataset.py
index 8a7111d..4034c1c 100644
--- a/HiPRGen/lmdb_dataset.py
+++ b/HiPRGen/lmdb_dataset.py
@@ -31,22 +31,50 @@ def __init__(self, config, transform=None):
         self.config = config
         self.path = Path(self.config["src"])
 
-        # Get metadata in case
-        # self.metadata_path = self.path.parent / "metadata.npz"
-        self.env = self.connect_db(self.path)
-
-        # If "length" encoded as ascii is present, use that
-        # If there are additional properties, there must be length.
-        length_entry = self.env.begin().get("length".encode("ascii"))
-        if length_entry is not None:
-            num_entries = pickle.loads(length_entry)
+        if not self.path.is_file():
+            # directory mode: treat every *.lmdb file in the directory as a shard
+            db_paths = sorted(self.path.glob("*.lmdb"))
+            assert len(db_paths) > 0, f"No LMDBs found in '{self.path}'"
+            #self.metadata_path = self.path / "metadata.npz"
+
+            self._keys = []
+            self.envs = []
+            for db_path in db_paths:
+                cur_env = self.connect_db(db_path)
+                self.envs.append(cur_env)
+
+                # If a "length" entry (ascii-encoded key) is present, use it
+                length_entry = cur_env.begin().get("length".encode("ascii"))
+                if length_entry is not None:
+                    num_entries = pickle.loads(length_entry)
+                else:
+                    # otherwise infer the number of stored samples from the
+                    # number of entries in the LMDB
+                    num_entries = cur_env.stat()["entries"]
+
+                # append the keys (0 -> num_entries) as a list
+                self._keys.append(list(range(num_entries)))
+
+            keylens = [len(k) for k in self._keys]
+            self._keylen_cumulative = np.cumsum(keylens).tolist()
+            self.num_samples = sum(keylens)
+
         else:
-            # Get the number of stores data from the number of entries
-            # in the LMDB
-            num_entries = self.env.stat()["entries"]
-
-        self._keys = list(range(num_entries))
-        self.num_samples = num_entries
+            # single-file mode
+            # self.metadata_path = self.path.parent / "metadata.npz"
+            self.env = self.connect_db(self.path)
+
+            # If a "length" entry (ascii-encoded key) is present, use it.
+            # If there are additional property entries, there must be a length entry.
+            length_entry = self.env.begin().get("length".encode("ascii"))
+            if length_entry is not None:
+                num_entries = pickle.loads(length_entry)
+            else:
+                # otherwise infer the number of stored samples from the
+                # number of entries in the LMDB
+                num_entries = self.env.stat()["entries"]
+
+            self._keys = list(range(num_entries))
+            self.num_samples = num_entries
 
         # Get portion of total dataset
         self.sharded = False
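# Editor's sketch (not part of the patch): the global-index -> (shard, local-index)
# arithmetic that __getitem__ below relies on, demonstrated on made-up shard
# sizes. Only numpy and the standard library are assumed.
import bisect
import numpy as np

keylens = [4, 2, 3]                              # samples per LMDB shard
keylen_cumulative = np.cumsum(keylens).tolist()  # [4, 6, 9]

def locate(idx):
    db_idx = bisect.bisect(keylen_cumulative, idx)
    el_idx = idx if db_idx == 0 else idx - keylen_cumulative[db_idx - 1]
    return db_idx, el_idx

assert locate(0) == (0, 0)   # first sample of the first shard
assert locate(4) == (1, 0)   # first sample of the second shard
assert locate(8) == (2, 2)   # last sample of the third shard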
@@ -71,15 +99,34 @@ def __getitem__(self, idx):
         # if sharding, remap idx to appropriate idx of the sharded set
         if self.sharded:
             idx = self.available_indices[idx]
+
+        if not self.path.is_file():
+            # Figure out which db this sample should be read from.
+            db_idx = bisect.bisect(self._keylen_cumulative, idx)
+            # Extract the index of the element within that db.
+            el_idx = idx
+            if db_idx != 0:
+                el_idx = idx - self._keylen_cumulative[db_idx - 1]
+            assert el_idx >= 0
+
+            # Return features.
+            datapoint_pickled = (
+                self.envs[db_idx]
+                .begin()
+                .get(f"{self._keys[db_idx][el_idx]}".encode("ascii"))
+            )
+            data_object = pickle.loads(datapoint_pickled)
+            #data_object.id = f"{db_idx}_{el_idx}"
+
+        else:
+            #!CHECK: len(self._keys) should be smaller than the total number of
+            # LMDB entries, since property entries are stored alongside samples.
+            datapoint_pickled = self.env.begin().get(f"{self._keys[idx]}".encode("ascii"))
-        #!CHECK, _keys should be less then total numbers of keys as there are more properties.
-        datapoint_pickled = self.env.begin().get(f"{self._keys[idx]}".encode("ascii"))
-
-        data_object = pickle.loads(datapoint_pickled)
+            data_object = pickle.loads(datapoint_pickled)
 
-        # TODO
-        if self.transform is not None:
-            data_object = self.transform(data_object)
+        # TODO
+        if self.transform is not None:
+            data_object = self.transform(data_object)
 
         return data_object
 
@@ -109,25 +156,30 @@ def get_metadata(self, num_samples=100):
 
 class LmdbMoleculeDataset(LmdbBaseDataset):
     def __init__(self, config, transform=None):
         super(LmdbMoleculeDataset, self).__init__(config=config, transform=transform)
-
+        if not self.path.is_file():
+            self.env_ = self.envs[0]
+            raise NotImplementedError("multi-LMDB (directory) mode is not implemented yet for molecule datasets")
+
+        else:
+            self.env_ = self.env
 
     @property
     def charges(self):
-        charges = self.env.begin().get("charges".encode("ascii"))
+        charges = self.env_.begin().get("charges".encode("ascii"))
         return pickle.loads(charges)
 
     @property
     def ring_sizes(self):
-        ring_sizes = self.env.begin().get("ring_sizes".encode("ascii"))
+        ring_sizes = self.env_.begin().get("ring_sizes".encode("ascii"))
         return pickle.loads(ring_sizes)
 
     @property
     def elements(self):
-        elements = self.env.begin().get("elements".encode("ascii"))
+        elements = self.env_.begin().get("elements".encode("ascii"))
         return pickle.loads(elements)
 
     @property
     def feature_info(self):
-        feature_info = self.env.begin().get("feature_info".encode("ascii"))
+        feature_info = self.env_.begin().get("feature_info".encode("ascii"))
         return pickle.loads(feature_info)
 
@@ -135,30 +187,83 @@ class LmdbReactionDataset(LmdbBaseDataset):
 
     def __init__(self, config, transform=None):
         super(LmdbReactionDataset, self).__init__(config=config, transform=transform)
 
+        if not self.path.is_file():
+            self.env_ = self.envs[0]
+            # check that the per-shard metadata is consistent across all LMDBs
+            for i in range(1, len(self.envs)):
+                for key in ["feature_size", "dtype", "feature_name"]:  # , "mean", "std"]:
+                    assert self.envs[i].begin().get(key.encode("ascii")) == self.envs[0].begin().get(key.encode("ascii"))
+            #! mean and std are not equal across the different datasets at this
+            #! time, so they are combined below rather than checked for equality.
+            # get the combined mean and std across all shards
+            mean_list = [pickle.loads(self.envs[i].begin().get("mean".encode("ascii"))) for i in range(0, len(self.envs))]
+            std_list = [pickle.loads(self.envs[i].begin().get("std".encode("ascii"))) for i in range(0, len(self.envs))]
+            count_list = [pickle.loads(self.envs[i].begin().get("length".encode("ascii"))) for i in range(0, len(self.envs))]
+            self._mean, self._std = combined_mean_std(mean_list, std_list, count_list)
+
+        else:
+            self.env_ = self.env
+            self._mean = pickle.loads(self.env_.begin().get("mean".encode("ascii")))
+            self._std = pickle.loads(self.env_.begin().get("std".encode("ascii")))
+
     @property
     def dtype(self):
-        dtype = self.env.begin().get("dtype".encode("ascii"))
+        dtype = self.env_.begin().get("dtype".encode("ascii"))
         return pickle.loads(dtype)
 
     @property
     def feature_size(self):
-        feature_size = self.env.begin().get("feature_size".encode("ascii"))
+        feature_size = self.env_.begin().get("feature_size".encode("ascii"))
         return pickle.loads(feature_size)
 
     @property
     def feature_name(self):
-        feature_name = self.env.begin().get("feature_name".encode("ascii"))
+        feature_name = self.env_.begin().get("feature_name".encode("ascii"))
         return pickle.loads(feature_name)
 
     @property
     def mean(self):
-        mean = self.env.begin().get("mean".encode("ascii"))
-        return pickle.loads(mean)
+        return self._mean
 
     @property
     def std(self):
-        std = self.env.begin().get("std".encode("ascii"))
-        return pickle.loads(std)
+        return self._std
+
+
+def combined_mean_std(mean_list, std_list, count_list):
+    """
+    Calculate the combined mean and standard deviation of multiple datasets.
+
+    :param mean_list: List of means of the datasets.
+    :param std_list: List of standard deviations of the datasets.
+    :param count_list: List of the number of data points in each dataset.
+    :return: Combined mean and standard deviation.
+    """
+    # total number of data points
+    total_count = sum(count_list)
+
+    # combined mean: count-weighted average of the per-dataset means
+    combined_mean = sum(mean * count for mean, count in zip(mean_list, count_list)) / total_count
+
+    # combined variance: within-dataset sum of squares (assuming the stored
+    # stds were computed with ddof=1) plus the between-dataset term
+    combined_variance = sum(
+        (std ** 2) * (count - 1) + count * (mean - combined_mean) ** 2
+        for mean, std, count in zip(mean_list, std_list, count_list)
+    ) / (total_count - 1)
+
+    combined_std = combined_variance ** 0.5
+
+    return combined_mean, combined_std
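# Editor's check (not part of the patch): combined_mean_std should reproduce
# the statistics of the concatenated data exactly. The two arrays below are
# made up; the per-shard stds are assumed to be sample stds (ddof=1), matching
# the (count - 1) weighting and (total_count - 1) divisor above.
import numpy as np

a = np.random.randn(100) * 2.0 + 1.0
b = np.random.randn(50) * 0.5 - 3.0

m, s = combined_mean_std(
    [a.mean(), b.mean()],
    [a.std(ddof=1), b.std(ddof=1)],
    [len(a), len(b)],
)
full = np.concatenate([a, b])
assert np.isclose(m, full.mean())
assert np.isclose(s, full.std(ddof=1))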
+ """ + # Calculate total number of data points + total_count = sum(count_list) + + # Calculate combined mean + combined_mean = sum(mean * count for mean, count in zip(mean_list, count_list)) / total_count + + # Calculate combined variance + combined_variance = sum( + ((std ** 2) * (count - 1) + count * (mean - combined_mean) ** 2 for mean, std, count in zip(mean_list, std_list, count_list)) + ) / (total_count - len(mean_list)) + + # Calculate combined standard deviation + combined_std = (combined_variance ** 0.5) + + return combined_mean, combined_std @@ -442,10 +547,10 @@ def write_to_lmdb(new_samples, current_length, lmdb_update, db_path): map_async=True, ) - pbar = tqdm( - total=len(new_samples), - desc=f"Adding new samples into LMDBs", - ) + # pbar = tqdm( + # total=len(new_samples), + # desc=f"Adding new samples into LMDBs", + # ) #write indexed samples idx = current_length @@ -456,7 +561,7 @@ def write_to_lmdb(new_samples, current_length, lmdb_update, db_path): pickle.dumps(sample, protocol=-1), ) idx += 1 - pbar.update(1) + #pbar.update(1) txn.commit() #write properties diff --git a/HiPRGen/reaction_filter.py b/HiPRGen/reaction_filter.py index 53d6d83..914e6d9 100644 --- a/HiPRGen/reaction_filter.py +++ b/HiPRGen/reaction_filter.py @@ -104,12 +104,13 @@ def log_message(*args, **kwargs): '[' + strftime('%H:%M:%S', localtime()) + ']', *args, **kwargs) +#restructure input of dispatcher def dispatcher( - mol_entries, - dgl_molecules_dict, - grapher_features, - dispatcher_payload, - reaction_lmdb_path + mol_entries, #1 + #dgl_molecules_dict, + #grapher_features, + dispatcher_payload, #2 + #reaction_lmdb_path ): comm = MPI.COMM_WORLD @@ -138,16 +139,17 @@ def dispatcher( #### HY ## initialize preprocess data - -#wx: writting lmdbs in dispatcher ? - rxn_networks_g = rxn_networks_graph( - mol_entries, - dgl_molecules_dict, - grapher_features, - dispatcher_payload.bondnet_test, - reaction_lmdb_path - ) - #### + +# #wx: writting lmdbs in dispatcher ? +# #wx: each worker needs to initlize rxn_networks_graph at worker level. +# rxn_networks_g = rxn_networks_graph( +# mol_entries, +# dgl_molecules_dict, +# grapher_features, +# dispatcher_payload.bondnet_test, +# reaction_lmdb_path #wx. different +# ) +# #### log_message("initializing report generator") @@ -162,6 +164,7 @@ def dispatcher( worker_states = {} worker_ranks = [i for i in range(comm.Get_size()) if i != DISPATCHER_RANK] + print("worker_states",worker_states) for i in worker_ranks: worker_states[i] = WorkerState.INITIALIZING @@ -173,6 +176,7 @@ def dispatcher( log_message("all workers running") + #global index which is different with local index in worker reaction_index = 0 log_message("handling requests") @@ -209,6 +213,7 @@ def dispatcher( tag = status.Get_tag() rank = status.Get_source() + #this is the last step when worker is out of work. if tag == SEND_ME_A_WORK_BATCH: if len(work_batch_list) == 0: comm.send(None, dest=rank, tag=HERE_IS_A_WORK_BATCH) @@ -224,9 +229,11 @@ def dispatcher( ": group ids:", group_id_0, group_id_1 ) - - - elif tag == NEW_REACTION_DB: + #this is where worker is doing things. found a good reation and send to dispatcher. + #if this is correct, then create_rxn_networks_graph operates on worker instead of dispatcher. + #This is the reason why adding samples one by one. QA: where is batch of reactions? 
@@ -278,17 +287,37 @@
 
 def worker(
-        mol_entries,
-        worker_payload
+        mol_entries,  # inputs of the worker
+        worker_payload,
+        dgl_molecules_dict,
+        grapher_features,
+        reaction_lmdb_path
 ):
 
+    local_reaction_idx = 0  #wx: per-worker local reaction index
+
     comm = MPI.COMM_WORLD
     con = sqlite3.connect(worker_payload.bucket_db_file)
     cur = con.cursor()
 
-    comm.send(None, dest=DISPATCHER_RANK, tag=INITIALIZATION_FINISHED)
+    rank = comm.Get_rank()  # id of this worker
+
+    # each worker writes to its own LMDB shard, named after its rank
+    lmdb_name_i = reaction_lmdb_path.split(".lmdb")[0] + "_" + str(rank) + ".lmdb"
+
+    rxn_networks_g = rxn_networks_graph(
+        mol_entries,
+        dgl_molecules_dict,
+        grapher_features,
+        #dispatcher_payload.bondnet_test,  # can be removed
+        lmdb_name_i  #wx: per-worker LMDB path
+    )
+
+    comm.send(None, dest=DISPATCHER_RANK, tag=INITIALIZATION_FINISHED)
+
     while True:
         comm.send(None, dest=DISPATCHER_RANK, tag=SEND_ME_A_WORK_BATCH)
         work_batch = comm.recv(source=DISPATCHER_RANK, tag=HERE_IS_A_WORK_BATCH)
@@ -355,6 +384,10 @@ def worker(
                 dest=DISPATCHER_RANK,
                 tag=NEW_REACTION_DB)
 
+            # after sending the reaction to the dispatcher, build its graph and
+            # write it to this worker's own LMDB under the local index
+            rxn_networks_g.create_rxn_networks_graph(reaction, local_reaction_idx)
+            local_reaction_idx += 1
 
             if run_decision_tree(reaction,
@@ -368,4 +401,4 @@ def worker(
                 ),
 
                 dest=DISPATCHER_RANK,
-                tag=NEW_REACTION_LOGGING)
\ No newline at end of file
+                tag=NEW_REACTION_LOGGING)
+""" + +# it is important that reaction_id is the primary key +# otherwise the network loader will be extremely slow. +create_reactions_table = """ + CREATE TABLE reactions ( + reaction_id INTEGER NOT NULL PRIMARY KEY, + number_of_reactants INTEGER NOT NULL, + number_of_products INTEGER NOT NULL, + reactant_1 INTEGER NOT NULL, + reactant_2 INTEGER NOT NULL, + product_1 INTEGER NOT NULL, + product_2 INTEGER NOT NULL, + rate REAL NOT NULL, + dG REAL NOT NULL, + dG_barrier REAL NOT NULL, + is_redox INTEGER NOT NULL + ); +""" + + +insert_reaction = """ + INSERT INTO reactions VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) +""" + +get_complex_group_sql = """ + SELECT * FROM complexes WHERE composition_id=? AND group_id=? +""" + + +# TODO: structure these global variables better +DISPATCHER_RANK = 0 + +# message tags + +# sent by workers to the dispatcher once they have finished initializing +# only sent once +INITIALIZATION_FINISHED = 0 + +# sent by workers to the dispatcher to request a new table +SEND_ME_A_WORK_BATCH = 1 + +# sent by dispatcher to workers when delivering a new table +HERE_IS_A_WORK_BATCH = 2 + +# sent by workers to the dispatcher when reaction passes db decision tree +NEW_REACTION_DB = 3 + +# sent by workers to the dispatcher when reaction passes logging decision tree +NEW_REACTION_LOGGING = 4 + +class WorkerState(Enum): + INITIALIZING = 0 + RUNNING = 1 + FINISHED = 2 + + +def log_message(*args, **kwargs): + print( + '[' + strftime('%H:%M:%S', localtime()) + ']', + *args, **kwargs) + +def dispatcher( #input of dispatcher. + mol_entries, #1 + dgl_molecules_dict, + grapher_features, + dispatcher_payload, #2 + #wx + reaction_lmdb_path +): + + comm = MPI.COMM_WORLD + work_batch_list = [] + bucket_con = sqlite3.connect(dispatcher_payload.bucket_db_file) + bucket_cur = bucket_con.cursor() + size_cur = bucket_con.cursor() + + res = bucket_cur.execute("SELECT * FROM group_counts") + for (composition_id, count) in res: + for (i,j) in product(range(count), repeat=2): + work_batch_list.append( + (composition_id, i, j)) + + composition_names = {} + res = bucket_cur.execute("SELECT * FROM compositions") + for (composition_id, composition) in res: + composition_names[composition_id] = composition + + log_message("creating reaction network db") + rn_con = sqlite3.connect(dispatcher_payload.reaction_network_db_file) + rn_cur = rn_con.cursor() + rn_cur.execute(create_metadata_table) + rn_cur.execute(create_reactions_table) + rn_con.commit() + + #### HY + ## initialize preprocess data + +#wx: writting lmdbs in dispatcher ? +#wx: each worker needs to initlize rxn_networks_graph at worker level. + rxn_networks_g = rxn_networks_graph( + mol_entries, + dgl_molecules_dict, + grapher_features, + dispatcher_payload.bondnet_test, + reaction_lmdb_path #wx. 
different + ) + #### + + log_message("initializing report generator") + + # since MPI processes spin lock, we don't want to have the dispathcer + # spend a bunch of time generating molecule pictures + report_generator = ReportGenerator( + mol_entries, + dispatcher_payload.report_file, + rebuild_mol_pictures=False + ) + + worker_states = {} + + worker_ranks = [i for i in range(comm.Get_size()) if i != DISPATCHER_RANK] + + for i in worker_ranks: + worker_states[i] = WorkerState.INITIALIZING + + for i in worker_states: + # block, waiting for workers to initialize + comm.recv(source=i, tag=INITIALIZATION_FINISHED) + worker_states[i] = WorkerState.RUNNING + + log_message("all workers running") + + reaction_index = 0 + + log_message("handling requests") + + batches_left_at_last_checkpoint = len(work_batch_list) + last_checkpoint_time = floor(time()) + while True: + if WorkerState.RUNNING not in worker_states.values(): + break + + current_time = floor(time()) + time_diff = current_time - last_checkpoint_time + if ( current_time % dispatcher_payload.checkpoint_interval == 0 and + time_diff > 0): + batches_left_at_current_checkpoint = len(work_batch_list) + batch_count_diff = ( + batches_left_at_last_checkpoint - + batches_left_at_current_checkpoint) + + batch_consumption_rate = batch_count_diff / time_diff + + log_message("batches remaining:", batches_left_at_current_checkpoint) + log_message("batch consumption rate:", + batch_consumption_rate, + "batches per second") + + + batches_left_at_last_checkpoint = batches_left_at_current_checkpoint + last_checkpoint_time = current_time + + + status = MPI.Status() + data = comm.recv(source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, status=status) + tag = status.Get_tag() + rank = status.Get_source() + + #this is the last step when worker is out of work. + if tag == SEND_ME_A_WORK_BATCH: + if len(work_batch_list) == 0: + comm.send(None, dest=rank, tag=HERE_IS_A_WORK_BATCH) + worker_states[rank] = WorkerState.FINISHED + else: + # pop removes and returns the last item in the list + work_batch = work_batch_list.pop() + comm.send(work_batch, dest=rank, tag=HERE_IS_A_WORK_BATCH) + composition_id, group_id_0, group_id_1 = work_batch + log_message( + "dispatched", + composition_names[composition_id], + ": group ids:", + group_id_0, group_id_1 + ) + #this is where worker is doing things. found a good reation and send to dispatcher. + #if this is correct, then create_rxn_networks_graph operates on worker instead of dispatcher. + #This is the reason why adding samples one by one. QA: where is batch of reactions? +#ten reactions, first filter out, second send , next eight + elif tag == NEW_REACTION_DB: + reaction = data + rn_cur.execute( + insert_reaction, + (reaction_index, + reaction['number_of_reactants'], + reaction['number_of_products'], + reaction['reactants'][0], + reaction['reactants'][1], + reaction['products'][0], + reaction['products'][1], + reaction['rate'], + reaction['dG'], + reaction['dG_barrier'], + reaction['is_redox'] + )) + + # # Create reaction graph + add to LMDB + rxn_networks_g.create_rxn_networks_graph(reaction, reaction_index) #wx in worker level + +#dispatch tracks global index, worker tracks local index in that batch. 
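# Editor's sketch (not part of the patch): how the per-worker LMDB paths fit
# with the directory mode of LmdbBaseDataset. The helper name is hypothetical;
# the string manipulation mirrors the worker code above.
def per_rank_lmdb_path(reaction_lmdb_path, rank):
    # ".../reaction.lmdb" -> ".../reaction_3.lmdb" for MPI rank 3
    return reaction_lmdb_path.split(".lmdb")[0] + "_" + str(rank) + ".lmdb"

assert per_rank_lmdb_path("lmdbs/reaction/reaction.lmdb", 3) == "lmdbs/reaction/reaction_3.lmdb"

# After a run, each worker has written its own shard; pointing a dataset at the
# directory (rather than a single file) globs *.lmdb and stitches them together.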
diff --git a/HiPRGen/rxn_networks_graph.py b/HiPRGen/rxn_networks_graph.py
index eb7721c..c6d1949 100644
--- a/HiPRGen/rxn_networks_graph.py
+++ b/HiPRGen/rxn_networks_graph.py
@@ -19,14 +19,14 @@ def __init__(
         mol_entries,
         dgl_molecules_dict,
         grapher_features,
-        report_file_path,
+        #report_file_path,
         reaction_lmdb_path #wx
     ):
         #wx, which one should come from molecule lmdbs?
         self.mol_entries = mol_entries
         self.dgl_mol_dict = dgl_molecules_dict
         self.grapher_features = grapher_features
-        self.report_file_path = report_file_path
+        #self.report_file_path = report_file_path
         self.reaction_lmdb_path = reaction_lmdb_path
 
diff --git a/run_network_generation.py b/run_network_generation.py
index d42e9af..d076837 100644
--- a/run_network_generation.py
+++ b/run_network_generation.py
@@ -36,16 +36,22 @@
 if rank == DISPATCHER_RANK:
     dispatcher_payload = loadfn(dispatcher_payload_json)
 
-    dispatcher(mol_entries,
-               dgl_molecules_dict_pickle_file,
-               grapher_features_dict_pickle_file,
-               dispatcher_payload,
-               #wx,
-               reaction_lmdb_path
-               )
-
+    dispatcher(mol_entries,
+               dispatcher_payload)
+
+    # the graph inputs and the reaction LMDB path moved to the worker level
 else:
     worker_payload = loadfn(worker_payload_json)
     worker(mol_entries,
-           worker_payload
+           worker_payload,
+           dgl_molecules_dict_pickle_file,
+           grapher_features_dict_pickle_file,
+           reaction_lmdb_path
            )
diff --git a/test.py b/test.py
index 8eee93b..223039b 100644
--- a/test.py
+++ b/test.py
@@ -890,7 +895,11 @@ def euvl_bondnet_test():
 
     ## HY
     bondnet_test_json = "./scratch/euvl_phase2_test/reaction_networks_graphs"
+    lmdbs_path_mol = "./scratch/euvl_phase2_test/lmdbs/mol"
+    lmdbs_path_reaction = "./scratch/euvl_phase2_test/lmdbs/reaction"
     subprocess.run(["mkdir", bondnet_test_json])
+    subprocess.run(["mkdir", "-p", lmdbs_path_mol])
+    subprocess.run(["mkdir", "-p", lmdbs_path_reaction])
     ##
 
     species_decision_tree = euvl_species_decision_tree
@@ -900,8 +909,10 @@ def euvl_bondnet_test():
         "electron_free_energy": 0.0,
     }
 
+    # with open(folder + "/mol_entries.pickle", 'rb') as f:
+    #     mol_entries = pickle.load(f)
+
     # import pdb
     # pdb.set_trace()
 
     mol_entries, dgl_molecules_dict = species_filter(
         #wx: dump mol lmdb at the end of species filter.
@@ -914,7 +928,7 @@ def euvl_bondnet_test():
         coordimer_weight=lambda mol: (mol.get_free_energy(params["temperature"])),
         species_logging_decision_tree=species_decision_tree,
         generate_unfiltered_mol_pictures=False,
-        mol_lmdb_path = folder + "/mol.lmdb",
+        mol_lmdb_path = folder + "/lmdbs/mol/mol.lmdb",
     )
@@ -951,14 +965,18 @@ def euvl_bondnet_test():
             number_of_threads,
             "python",
             "run_network_generation.py",
             folder + "/mol_entries.pickle",
             folder + "/dispatcher_payload.json",
             folder + "/worker_payload.json",
             folder + "/dgl_mol_graphs.pickle",
             folder + "/grapher_features.pickle",
             #wx, path to write reaction lmdb
-            folder + "/reaction.lmdb"
+            folder + "/lmdbs/reaction/reaction.lmdb"
         ]
     )
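# Editor's sketch (not part of the patch): reading back the per-worker reaction
# LMDBs written by this test. It assumes the dataset config only needs "src"
# and that LmdbBaseDataset exposes num_samples; both are inferred from the diff
# above, not verified against the full source.
from HiPRGen.lmdb_dataset import LmdbReactionDataset

dataset = LmdbReactionDataset({"src": "./scratch/euvl_phase2_test/lmdbs/reaction"})
print(dataset.num_samples, dataset.feature_size, dataset.feature_name)
print(dataset.mean, dataset.std)  # combined across shards via combined_mean_std
sample = dataset[0]               # routed to the right shard via bisect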