diff --git a/HiPRGen/reaction_filter.py b/HiPRGen/reaction_filter.py index 5e672f1..02f2c7f 100644 --- a/HiPRGen/reaction_filter.py +++ b/HiPRGen/reaction_filter.py @@ -240,7 +240,9 @@ def dispatcher( reaction['is_redox'] )) + # Create reaction graph + add to LMDB rxn_networks_g.create_rxn_networks_graph(reaction, reaction_index) + reaction_index += 1 if reaction_index % dispatcher_payload.commit_frequency == 0: rn_con.commit() @@ -350,12 +352,6 @@ def worker( dest=DISPATCHER_RANK, tag=NEW_REACTION_DB) - # rxn_networks_g.create_rxn_networks_graph(reaction, reaction_index) - # reaction_index += 1 - # if reaction_index % dispatcher_payload.commit_frequency == 0: - # rn_con.commit() - - if run_decision_tree(reaction, diff --git a/HiPRGen/rxn_networks_graph.py b/HiPRGen/rxn_networks_graph.py index 277731b..e94b005 100644 --- a/HiPRGen/rxn_networks_graph.py +++ b/HiPRGen/rxn_networks_graph.py @@ -5,7 +5,7 @@ import copy from collections import defaultdict from monty.serialization import dumpfn -from bondnet.data.utils import create_rxn_graph +from bondnet.data.utils import construct_rxn_graph_empty from HiPRGen.lmdb_dataset import LmdbDataset import lmdb import tqdm @@ -229,36 +229,40 @@ def find_total_bonds(rxn, species, reactants, products): # print(f"has_bonds: {has_bonds}") # print(f"mappings: {mappings}") - # step 5: Create a reaction graphs and features - rxn_graph, features = create_rxn_graph( - reactants = reactants_dgl_graphs, - products = products_dgl_graphs, - mappings = mappings, - has_bonds = has_bonds, - device = None, - ntypes=("global", "atom", "bond"), - ft_name="feat", - reverse=False, - ) - - # print(f"rxn_graph: {rxn_graph}") - if rxn['is_redox']: - print(f"mappings: {mappings}") - print(f"features: {features}") - print(f"transformed_atom_map: {transformed_atom_map}") - print(f"atom_map: {atom_map}") + # # step 5: Create a reaction graphs and features + # rxn_graph, features = create_rxn_graph( + # reactants = reactants_dgl_graphs, + # products = products_dgl_graphs, + # mappings = mappings, + # has_bonds = has_bonds, + # device = None, + # ntypes=("global", "atom", "bond"), + # ft_name="feat", + # reverse=False, + # zero_fts=True, + # ) + + # # print(f"rxn_graph: {rxn_graph}") + # if rxn['is_redox']: + # print(f"mappings: {mappings}") + # print(f"features: {features}") + # print(f"transformed_atom_map: {transformed_atom_map}") + # print(f"atom_map: {atom_map}") + + # # step 5: update reaction features to the reaction graph + # for nt, ft in features.items(): + # # print(f"nt: {nt}") + # # print(f"ft: {ft}") + # rxn_graph.nodes[nt].data.update({'ft': ft}) - # step 5: update reaction features to the reaction graph - for nt, ft in features.items(): - # print(f"nt: {nt}") - # print(f"ft: {ft}") - rxn_graph.nodes[nt].data.update({'ft': ft}) + rxn_graph = construct_rxn_graph_empty(mappings) # step 6: save a reaction graph and dG self.data[rxn_id] = {} # {'id': {}} self.data[rxn_id]['rxn_graph'] = rxn_graph self.data[rxn_id]['value'] = rxn['dG'] #torch.tensor([rxn['dG']]) - self.data[rxn_id]['reaction_features'] = features + self.data[rxn_id]['mappings'] = mappings + # self.data[rxn_id]['reaction_features'] = features #### Write LMDB #### diff --git a/HiPRGen/species_filter.py b/HiPRGen/species_filter.py index c1562dd..82d87cd 100644 --- a/HiPRGen/species_filter.py +++ b/HiPRGen/species_filter.py @@ -305,6 +305,9 @@ def collapse_isomorphism_group(g): #print(dgl_molecules_dict) + #ADD WRITING MOLECULE LMDB HERE!!! + + log_message("creating molecule entry pickle") # ideally we would serialize mol_entries to a json # some of the auxilary_data we compute