diff --git a/HiPRGen/mol_entry.py b/HiPRGen/mol_entry.py index 6bb92bd..3b24309 100644 --- a/HiPRGen/mol_entry.py +++ b/HiPRGen/mol_entry.py @@ -46,9 +46,12 @@ def find_fragment_atom_mappings(fragment_1, fragment_2, return_one=True): hot_f2_indices.append(fragment_2.hot_atoms[hot_nbh_hash]) hot_f1_indices.append(fragment_1.hot_atoms[hot_nbh_hash]) + #print("hot_f1_indices:", hot_f1_indices) + #print("hot_f2_indices:", hot_f2_indices) for left_index in fragment_1.compressed_graph.nodes(): - + #print("left_index:", left_index) + neighborhood_hash = fragment_1.neighborhood_hashes[left_index] if neighborhood_hash in match_hot_atoms and left_index in hot_f1_indices: neighborhood_hash = neighborhood_hash + "hot" @@ -57,27 +60,40 @@ def find_fragment_atom_mappings(fragment_1, fragment_2, return_one=True): groups_by_hash[neighborhood_hash][0].append(left_index) + #print("groups_by_hash:", groups_by_hash) for right_index in fragment_2.compressed_graph.nodes(): + #print("right_index:", right_index) neighborhood_hash = fragment_2.neighborhood_hashes[right_index] if neighborhood_hash in match_hot_atoms and right_index in hot_f2_indices: neighborhood_hash = neighborhood_hash + "hot" if neighborhood_hash not in groups_by_hash: + #print("NEW HASH HOW IS THIS POSSIBLE?!!!??!?!") + #print(huh) groups_by_hash[neighborhood_hash] = ([],[]) groups_by_hash[neighborhood_hash][1].append(right_index) + #print("groups_by_hash:", groups_by_hash) + groups = list(groups_by_hash.values()) + #print("groups:", groups) product_sym_iterator = product(*[sym_iterator(len(p[0])) for p in groups]) + #print("product_sym_iterator:", product_sym_iterator) mappings = [] for product_perm in product_sym_iterator: + #print("product_perm:", product_perm) mapping = {} for perm, vals in zip(product_perm, groups): + #print("perm:", perm) + #print("vals:", vals) for i, j in enumerate(perm): + #print("i:", i) + #print("j:", j) mapping[vals[0][i]] = vals[1][j] isomorphism = True diff --git a/HiPRGen/reaction_questions.py b/HiPRGen/reaction_questions.py index 698b50d..28f962e 100644 --- a/HiPRGen/reaction_questions.py +++ b/HiPRGen/reaction_questions.py @@ -60,20 +60,20 @@ class Terminal(Enum): KEEP = 1 DISCARD = -1 hydrogen_graph = nx.MultiGraph() hydrogen_graph.add_node(0, specie="H") -hydrogen_hash = weisfeiler_lehman_graph_hash(hydrogen_graph, node_attr="specie") +hydrogen_hash = weisfeiler_lehman_graph_hash(hydrogen_graph, node_attr="specie", iterations=4) fluorine_graph = nx.MultiGraph() fluorine_graph.add_node(0, specie="F") -fluorine_hash = weisfeiler_lehman_graph_hash(fluorine_graph, node_attr="specie") +fluorine_hash = weisfeiler_lehman_graph_hash(fluorine_graph, node_attr="specie", iterations=4) carbon_graph = nx.MultiGraph() carbon_graph.add_node(0, specie="C") -carbon_hash = weisfeiler_lehman_graph_hash(carbon_graph, node_attr="specie") +carbon_hash = weisfeiler_lehman_graph_hash(carbon_graph, node_attr="specie", iterations=4) co3_mol = Molecule.from_file("xyz_files/co3.xyz") co3_mg = MoleculeGraph.with_local_env_strategy(co3_mol, OpenBabelNN()) co3_g = co3_mg.graph.to_undirected() -co3_hash = weisfeiler_lehman_graph_hash(co3_g, node_attr="specie") +co3_hash = weisfeiler_lehman_graph_hash(co3_g, node_attr="specie", iterations=4) def run_decision_tree( @@ -1939,27 +1939,27 @@ def __call__(self, reaction, mol_entries, params): (dG_above_threshold(0.0, "free_energy", 0.0), Terminal.DISCARD), (reactants_are_both_anions_or_both_cations(), Terminal.DISCARD), (reaction_is_charge_transfer(), Terminal.KEEP), - # (reaction_is_covalent_charge_decomposable(), Terminal.DISCARD), - # (reaction_is_coupled_electron_fragment_transfer(), Terminal.DISCARD), - # (star_count_diff_above_threshold(6), Terminal.DISCARD), - # ( - # fragment_matching_found(), - # [ - # (single_reactant_single_product_not_atom_transfer(), Terminal.DISCARD), - # (single_reactant_double_product_ring_close(), Terminal.DISCARD), - # (reaction_is_hindered(), Terminal.DISCARD), - # ( - # reaction_is_covalent_decomposable(), - # [ - # (fragments_are_not_2A_B(), Terminal.DISCARD), - # (mapping_with_reaction_center_not_found(), Terminal.DISCARD), - # (reaction_default_true(), Terminal.KEEP), - # ], - # ), - # (mapping_with_reaction_center_not_found(), Terminal.DISCARD), - # (reaction_default_true(), Terminal.KEEP), - # ], - # ), + (reaction_is_covalent_charge_decomposable(), Terminal.DISCARD), + (reaction_is_coupled_electron_fragment_transfer(), Terminal.DISCARD), + (star_count_diff_above_threshold(6), Terminal.DISCARD), + ( + fragment_matching_found(), + [ + (single_reactant_single_product_not_atom_transfer(), Terminal.DISCARD), + (single_reactant_double_product_ring_close(), Terminal.DISCARD), + (reaction_is_hindered(), Terminal.DISCARD), + ( + reaction_is_covalent_decomposable(), + [ + (fragments_are_not_2A_B(), Terminal.DISCARD), + (mapping_with_reaction_center_not_found(), Terminal.DISCARD), + (reaction_default_true(), Terminal.KEEP), + ], + ), + (mapping_with_reaction_center_not_found(), Terminal.DISCARD), + (reaction_default_true(), Terminal.KEEP), + ], + ), (reaction_default_true(), Terminal.DISCARD), ] @@ -1990,4 +1990,4 @@ def __call__(self, reaction, mol_entries, params): ], ), (reaction_default_true(), Terminal.DISCARD), -] \ No newline at end of file +] diff --git a/HiPRGen/species_filter.py b/HiPRGen/species_filter.py index 7f43a32..005d289 100644 --- a/HiPRGen/species_filter.py +++ b/HiPRGen/species_filter.py @@ -213,14 +213,28 @@ def collapse_isomorphism_group(g): log_message("mapping fragments") fragment_dict = {} for mol in mol_entries: - # print(mol.entry_id) + log_message("") + log_message(mol.entry_id) + f = open("mapping_record", "a") + f.write("Mapping " + mol.entry_id + "\n") + f.close() for fragment_complex in mol.fragment_data: + log_message("new fragment complex") + log_message("bonds broken:" + str(fragment_complex.bonds_broken)) for ii, fragment in enumerate(fragment_complex.fragment_objects): + log_message(str(ii)) hot_nbh_hashes = list(fragment.hot_atoms.keys()) + log_message("hot_nbh_hashes:" + str(hot_nbh_hashes)) assert len(hot_nbh_hashes) == 0 or len(hot_nbh_hashes) == 1 or len(hot_nbh_hashes) == 2 assert fragment.fragment_hash == fragment_complex.fragment_hashes[ii] if fragment.fragment_hash not in fragment_dict: + log_message("Adding new fragment! Hash:" + fragment.fragment_hash) fragment_dict[fragment.fragment_hash] = copy.deepcopy(fragment) + else: + log_message("Fragment hash already in dict") + log_message("fragment hash:" + fragment.fragment_hash) + #print("fragment:", fragment) + #print("fragment_dict[fragment.fragment_hash]):", fragment_dict[fragment.fragment_hash]) all_mappings = find_fragment_atom_mappings( fragment, fragment_dict[fragment.fragment_hash]) @@ -407,4 +421,4 @@ def add_electron_species( def clean(input): return "".join([i for i in input if not i.isdigit()]) def clean_op(input): - return "".join([i for i in input if i.isdigit()]) \ No newline at end of file + return "".join([i for i in input if i.isdigit()]) diff --git a/HiPRGen/species_questions.py b/HiPRGen/species_questions.py index 6f5e2c5..07bde14 100644 --- a/HiPRGen/species_questions.py +++ b/HiPRGen/species_questions.py @@ -123,7 +123,7 @@ def __call__(self, mol): ) #covalently bonded to mol.star_hashes[i] = weisfeiler_lehman_graph_hash( #star_hashes is a dictionary, and this adds an entry to it - neighborhood, node_attr="specie" #with the atom index, i, as the key and a graph_hash (string) + neighborhood, node_attr="specie", iterations=6 #with the atom index, i, as the key and a graph_hash (string) ) #as the value return False @@ -151,7 +151,9 @@ def __call__(self, mol): neighborhood_hash = weisfeiler_lehman_graph_hash( neighborhood, - node_attr='specie') + node_attr='specie', + iterations=6, + ) hash_list.append(neighborhood_hash) @@ -196,7 +198,7 @@ def __call__(self, mol): subgraph = h.subgraph(c) #generates a subgraph from one set of nodes (this is a fragment graph) fragment_hash = weisfeiler_lehman_graph_hash( #saves the hash of this graph - subgraph, node_attr="specie" + subgraph, node_attr="specie", iterations=6 ) tmp["c"] = copy.deepcopy(c) @@ -237,7 +239,9 @@ def __call__(self, mol): neighborhood_hash = weisfeiler_lehman_graph_hash( neighborhood, - node_attr='specie') + node_attr='specie', + iterations=6, + ) hash_list.append(neighborhood_hash) @@ -342,7 +346,7 @@ def __call__(self, mol): subgraph = h.subgraph(c) fragment_hash = weisfeiler_lehman_graph_hash( - subgraph, node_attr="specie" + subgraph, node_attr="specie", iterations=6 ) fragments.append(fragment_hash) @@ -389,6 +393,14 @@ def __call__(self, mol): return mol.formula == "H1 O1" and mol.charge == 1 +class formula_filter(MSONable): + def __init__(self, formula): + self.formula = formula + + def __call__(self, mol): + return mol.formula == self.formula + + class fix_hydrogen_bonding(MSONable): def __init__(self): pass @@ -514,10 +526,10 @@ def __call__(self, mol): def compute_graph_hashes(mol): - mol.total_hash = weisfeiler_lehman_graph_hash(mol.graph, node_attr="specie") + mol.total_hash = weisfeiler_lehman_graph_hash(mol.graph, node_attr="specie", iterations=6) mol.covalent_hash = weisfeiler_lehman_graph_hash( - mol.covalent_graph, node_attr="specie" + mol.covalent_graph, node_attr="specie", iterations=6 ) return False @@ -608,6 +620,7 @@ def __call__(self, mol): (fix_hydrogen_bonding(), Terminal.KEEP), (h_atom_filter(), Terminal.DISCARD), (oh_plus_filter(), Terminal.DISCARD), + (formula_filter("C36 H28 S2"), Terminal.DISCARD), (compute_graph_hashes, Terminal.KEEP), (add_star_hashes(), Terminal.KEEP), (add_unbroken_fragment(neighborhood_width=width), Terminal.KEEP), diff --git a/test.py b/test.py index 223039b..54a6679 100644 --- a/test.py +++ b/test.py @@ -710,8 +710,23 @@ def euvl_phase2_test(): folder = "./scratch/euvl_phase2_test" subprocess.run(["mkdir", folder]) + bondnet_test_json = "./scratch/euvl_phase2_test/reaction_networks_graphs" + lmdbs_path_mol = "./scratch/euvl_phase2_test/lmdbs/mol" + lmdbs_path_reaction = "./scratch/euvl_phase2_test/lmdbs/reaction" + subprocess.run(["mkdir", bondnet_test_json]) + subprocess.run(["mkdir", "-p",lmdbs_path_mol]) + subprocess.run(["mkdir", "-p",lmdbs_path_reaction]) + mol_json = "./data/euvl_test_set.json" + #mol_json = "./data/problem_entries.json" database_entries = loadfn(mol_json) + #database_entries = [database_entries[2], database_entries[4]] + #database_entries = [database_entries[3], database_entries[4]] + + problem_entries = loadfn("./data/big_entries.json") + for entry in problem_entries: + database_entries.append(entry) + species_decision_tree = euvl_species_decision_tree @@ -720,14 +735,17 @@ def euvl_phase2_test(): "electron_free_energy": 0.0, } - mol_entries = species_filter( + mol_entries, dgl_molecules_dict = species_filter( database_entries, mol_entries_pickle_location=folder + "/mol_entries.pickle", + dgl_mol_grphs_pickle_location = folder + "/dgl_mol_graphs.pickle", + grapher_features_pickle_location= folder + "/grapher_features.pickle", species_report=folder + "/unfiltered_species_report.tex", species_decision_tree=species_decision_tree, coordimer_weight=lambda mol: (mol.get_free_energy(params["temperature"])), species_logging_decision_tree=species_decision_tree, generate_unfiltered_mol_pictures=False, + mol_lmdb_path = folder + "/lmdbs/mol/mol.lmdb", ) print(len(mol_entries), "initial mol entries") @@ -740,6 +758,7 @@ def euvl_phase2_test(): folder + "/buckets.sqlite", folder + "/rn.sqlite", folder + "/reaction_report.tex", + bondnet_test_json + "/test.json" ) worker_payload = WorkerPayload( @@ -763,6 +782,9 @@ def euvl_phase2_test(): folder + "/mol_entries.pickle", folder + "/dispatcher_payload.json", folder + "/worker_payload.json", + folder + "/dgl_mol_graphs.pickle", + folder + "/grapher_features.pickle", + folder + "/lmdbs/reaction/reaction.lmdb" ] ) @@ -869,7 +891,7 @@ def euvl_phase2_test(): tests_passed = False print("Number of reactions:", network_loader.number_of_reactions) - if network_loader.number_of_reactions == 3912: + if network_loader.number_of_reactions == 3905: print(bcolors.PASS + "euvl_phase_2_test: correct number of reactions" + bcolors.ENDC) else: print(bcolors.FAIL + "euvl_phase_2_test: correct number of reactions" + bcolors.ENDC) @@ -1007,8 +1029,8 @@ def euvl_bondnet_test(): # flicho_test, # co2_test, # euvl_phase1_test, - # euvl_phase2_test, - euvl_bondnet_test + euvl_phase2_test, + # euvl_bondnet_test ] for test in tests: