Skip to content

Commit

Permalink
Merge pull request #18 from samblau/BonDNet
Browse files Browse the repository at this point in the history
Debugging and fixing production
  • Loading branch information
samblau authored Dec 24, 2023
2 parents 182e0a6 + 5089769 commit a5ea643
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 40 deletions.
18 changes: 17 additions & 1 deletion HiPRGen/mol_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,12 @@ def find_fragment_atom_mappings(fragment_1, fragment_2, return_one=True):
hot_f2_indices.append(fragment_2.hot_atoms[hot_nbh_hash])
hot_f1_indices.append(fragment_1.hot_atoms[hot_nbh_hash])

#print("hot_f1_indices:", hot_f1_indices)
#print("hot_f2_indices:", hot_f2_indices)

for left_index in fragment_1.compressed_graph.nodes():

#print("left_index:", left_index)

neighborhood_hash = fragment_1.neighborhood_hashes[left_index]
if neighborhood_hash in match_hot_atoms and left_index in hot_f1_indices:
neighborhood_hash = neighborhood_hash + "hot"
Expand All @@ -57,27 +60,40 @@ def find_fragment_atom_mappings(fragment_1, fragment_2, return_one=True):

groups_by_hash[neighborhood_hash][0].append(left_index)

#print("groups_by_hash:", groups_by_hash)

for right_index in fragment_2.compressed_graph.nodes():
#print("right_index:", right_index)

neighborhood_hash = fragment_2.neighborhood_hashes[right_index]
if neighborhood_hash in match_hot_atoms and right_index in hot_f2_indices:
neighborhood_hash = neighborhood_hash + "hot"
if neighborhood_hash not in groups_by_hash:
#print("NEW HASH HOW IS THIS POSSIBLE?!!!??!?!")
#print(huh)
groups_by_hash[neighborhood_hash] = ([],[])

groups_by_hash[neighborhood_hash][1].append(right_index)

#print("groups_by_hash:", groups_by_hash)

groups = list(groups_by_hash.values())
#print("groups:", groups)

product_sym_iterator = product(*[sym_iterator(len(p[0])) for p in groups])

#print("product_sym_iterator:", product_sym_iterator)
mappings = []

for product_perm in product_sym_iterator:
#print("product_perm:", product_perm)
mapping = {}
for perm, vals in zip(product_perm, groups):
#print("perm:", perm)
#print("vals:", vals)
for i, j in enumerate(perm):
#print("i:", i)
#print("j:", j)
mapping[vals[0][i]] = vals[1][j]

isomorphism = True
Expand Down
52 changes: 26 additions & 26 deletions HiPRGen/reaction_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,20 +60,20 @@ class Terminal(Enum): KEEP = 1 DISCARD = -1

hydrogen_graph = nx.MultiGraph()
hydrogen_graph.add_node(0, specie="H")
hydrogen_hash = weisfeiler_lehman_graph_hash(hydrogen_graph, node_attr="specie")
hydrogen_hash = weisfeiler_lehman_graph_hash(hydrogen_graph, node_attr="specie", iterations=4)

fluorine_graph = nx.MultiGraph()
fluorine_graph.add_node(0, specie="F")
fluorine_hash = weisfeiler_lehman_graph_hash(fluorine_graph, node_attr="specie")
fluorine_hash = weisfeiler_lehman_graph_hash(fluorine_graph, node_attr="specie", iterations=4)

carbon_graph = nx.MultiGraph()
carbon_graph.add_node(0, specie="C")
carbon_hash = weisfeiler_lehman_graph_hash(carbon_graph, node_attr="specie")
carbon_hash = weisfeiler_lehman_graph_hash(carbon_graph, node_attr="specie", iterations=4)

co3_mol = Molecule.from_file("xyz_files/co3.xyz")
co3_mg = MoleculeGraph.with_local_env_strategy(co3_mol, OpenBabelNN())
co3_g = co3_mg.graph.to_undirected()
co3_hash = weisfeiler_lehman_graph_hash(co3_g, node_attr="specie")
co3_hash = weisfeiler_lehman_graph_hash(co3_g, node_attr="specie", iterations=4)


def run_decision_tree(
Expand Down Expand Up @@ -1939,27 +1939,27 @@ def __call__(self, reaction, mol_entries, params):
(dG_above_threshold(0.0, "free_energy", 0.0), Terminal.DISCARD),
(reactants_are_both_anions_or_both_cations(), Terminal.DISCARD),
(reaction_is_charge_transfer(), Terminal.KEEP),
# (reaction_is_covalent_charge_decomposable(), Terminal.DISCARD),
# (reaction_is_coupled_electron_fragment_transfer(), Terminal.DISCARD),
# (star_count_diff_above_threshold(6), Terminal.DISCARD),
# (
# fragment_matching_found(),
# [
# (single_reactant_single_product_not_atom_transfer(), Terminal.DISCARD),
# (single_reactant_double_product_ring_close(), Terminal.DISCARD),
# (reaction_is_hindered(), Terminal.DISCARD),
# (
# reaction_is_covalent_decomposable(),
# [
# (fragments_are_not_2A_B(), Terminal.DISCARD),
# (mapping_with_reaction_center_not_found(), Terminal.DISCARD),
# (reaction_default_true(), Terminal.KEEP),
# ],
# ),
# (mapping_with_reaction_center_not_found(), Terminal.DISCARD),
# (reaction_default_true(), Terminal.KEEP),
# ],
# ),
(reaction_is_covalent_charge_decomposable(), Terminal.DISCARD),
(reaction_is_coupled_electron_fragment_transfer(), Terminal.DISCARD),
(star_count_diff_above_threshold(6), Terminal.DISCARD),
(
fragment_matching_found(),
[
(single_reactant_single_product_not_atom_transfer(), Terminal.DISCARD),
(single_reactant_double_product_ring_close(), Terminal.DISCARD),
(reaction_is_hindered(), Terminal.DISCARD),
(
reaction_is_covalent_decomposable(),
[
(fragments_are_not_2A_B(), Terminal.DISCARD),
(mapping_with_reaction_center_not_found(), Terminal.DISCARD),
(reaction_default_true(), Terminal.KEEP),
],
),
(mapping_with_reaction_center_not_found(), Terminal.DISCARD),
(reaction_default_true(), Terminal.KEEP),
],
),
(reaction_default_true(), Terminal.DISCARD),
]

Expand Down Expand Up @@ -1990,4 +1990,4 @@ def __call__(self, reaction, mol_entries, params):
],
),
(reaction_default_true(), Terminal.DISCARD),
]
]
18 changes: 16 additions & 2 deletions HiPRGen/species_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,14 +213,28 @@ def collapse_isomorphism_group(g):
log_message("mapping fragments")
fragment_dict = {}
for mol in mol_entries:
# print(mol.entry_id)
log_message("")
log_message(mol.entry_id)
f = open("mapping_record", "a")
f.write("Mapping " + mol.entry_id + "\n")
f.close()
for fragment_complex in mol.fragment_data:
log_message("new fragment complex")
log_message("bonds broken:" + str(fragment_complex.bonds_broken))
for ii, fragment in enumerate(fragment_complex.fragment_objects):
log_message(str(ii))
hot_nbh_hashes = list(fragment.hot_atoms.keys())
log_message("hot_nbh_hashes:" + str(hot_nbh_hashes))
assert len(hot_nbh_hashes) == 0 or len(hot_nbh_hashes) == 1 or len(hot_nbh_hashes) == 2
assert fragment.fragment_hash == fragment_complex.fragment_hashes[ii]
if fragment.fragment_hash not in fragment_dict:
log_message("Adding new fragment! Hash:" + fragment.fragment_hash)
fragment_dict[fragment.fragment_hash] = copy.deepcopy(fragment)
else:
log_message("Fragment hash already in dict")
log_message("fragment hash:" + fragment.fragment_hash)
#print("fragment:", fragment)
#print("fragment_dict[fragment.fragment_hash]):", fragment_dict[fragment.fragment_hash])
all_mappings = find_fragment_atom_mappings(
fragment,
fragment_dict[fragment.fragment_hash])
Expand Down Expand Up @@ -407,4 +421,4 @@ def add_electron_species(
def clean(input):
return "".join([i for i in input if not i.isdigit()])
def clean_op(input):
return "".join([i for i in input if i.isdigit()])
return "".join([i for i in input if i.isdigit()])
27 changes: 20 additions & 7 deletions HiPRGen/species_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def __call__(self, mol):
) #covalently bonded to

mol.star_hashes[i] = weisfeiler_lehman_graph_hash( #star_hashes is a dictionary, and this adds an entry to it
neighborhood, node_attr="specie" #with the atom index, i, as the key and a graph_hash (string)
neighborhood, node_attr="specie", iterations=6 #with the atom index, i, as the key and a graph_hash (string)
) #as the value

return False
Expand Down Expand Up @@ -151,7 +151,9 @@ def __call__(self, mol):

neighborhood_hash = weisfeiler_lehman_graph_hash(
neighborhood,
node_attr='specie')
node_attr='specie',
iterations=6,
)

hash_list.append(neighborhood_hash)

Expand Down Expand Up @@ -196,7 +198,7 @@ def __call__(self, mol):
subgraph = h.subgraph(c) #generates a subgraph from one set of nodes (this is a fragment graph)

fragment_hash = weisfeiler_lehman_graph_hash( #saves the hash of this graph
subgraph, node_attr="specie"
subgraph, node_attr="specie", iterations=6
)

tmp["c"] = copy.deepcopy(c)
Expand Down Expand Up @@ -237,7 +239,9 @@ def __call__(self, mol):

neighborhood_hash = weisfeiler_lehman_graph_hash(
neighborhood,
node_attr='specie')
node_attr='specie',
iterations=6,
)

hash_list.append(neighborhood_hash)

Expand Down Expand Up @@ -342,7 +346,7 @@ def __call__(self, mol):
subgraph = h.subgraph(c)

fragment_hash = weisfeiler_lehman_graph_hash(
subgraph, node_attr="specie"
subgraph, node_attr="specie", iterations=6
)

fragments.append(fragment_hash)
Expand Down Expand Up @@ -389,6 +393,14 @@ def __call__(self, mol):
return mol.formula == "H1 O1" and mol.charge == 1


class formula_filter(MSONable):
def __init__(self, formula):
self.formula = formula

def __call__(self, mol):
return mol.formula == self.formula


class fix_hydrogen_bonding(MSONable):
def __init__(self):
pass
Expand Down Expand Up @@ -514,10 +526,10 @@ def __call__(self, mol):


def compute_graph_hashes(mol):
mol.total_hash = weisfeiler_lehman_graph_hash(mol.graph, node_attr="specie")
mol.total_hash = weisfeiler_lehman_graph_hash(mol.graph, node_attr="specie", iterations=6)

mol.covalent_hash = weisfeiler_lehman_graph_hash(
mol.covalent_graph, node_attr="specie"
mol.covalent_graph, node_attr="specie", iterations=6
)

return False
Expand Down Expand Up @@ -608,6 +620,7 @@ def __call__(self, mol):
(fix_hydrogen_bonding(), Terminal.KEEP),
(h_atom_filter(), Terminal.DISCARD),
(oh_plus_filter(), Terminal.DISCARD),
(formula_filter("C36 H28 S2"), Terminal.DISCARD),
(compute_graph_hashes, Terminal.KEEP),
(add_star_hashes(), Terminal.KEEP),
(add_unbroken_fragment(neighborhood_width=width), Terminal.KEEP),
Expand Down
30 changes: 26 additions & 4 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,8 +710,23 @@ def euvl_phase2_test():
folder = "./scratch/euvl_phase2_test"
subprocess.run(["mkdir", folder])

bondnet_test_json = "./scratch/euvl_phase2_test/reaction_networks_graphs"
lmdbs_path_mol = "./scratch/euvl_phase2_test/lmdbs/mol"
lmdbs_path_reaction = "./scratch/euvl_phase2_test/lmdbs/reaction"
subprocess.run(["mkdir", bondnet_test_json])
subprocess.run(["mkdir", "-p",lmdbs_path_mol])
subprocess.run(["mkdir", "-p",lmdbs_path_reaction])

mol_json = "./data/euvl_test_set.json"
#mol_json = "./data/problem_entries.json"
database_entries = loadfn(mol_json)
#database_entries = [database_entries[2], database_entries[4]]
#database_entries = [database_entries[3], database_entries[4]]

problem_entries = loadfn("./data/big_entries.json")
for entry in problem_entries:
database_entries.append(entry)


species_decision_tree = euvl_species_decision_tree

Expand All @@ -720,14 +735,17 @@ def euvl_phase2_test():
"electron_free_energy": 0.0,
}

mol_entries = species_filter(
mol_entries, dgl_molecules_dict = species_filter(
database_entries,
mol_entries_pickle_location=folder + "/mol_entries.pickle",
dgl_mol_grphs_pickle_location = folder + "/dgl_mol_graphs.pickle",
grapher_features_pickle_location= folder + "/grapher_features.pickle",
species_report=folder + "/unfiltered_species_report.tex",
species_decision_tree=species_decision_tree,
coordimer_weight=lambda mol: (mol.get_free_energy(params["temperature"])),
species_logging_decision_tree=species_decision_tree,
generate_unfiltered_mol_pictures=False,
mol_lmdb_path = folder + "/lmdbs/mol/mol.lmdb",
)

print(len(mol_entries), "initial mol entries")
Expand All @@ -740,6 +758,7 @@ def euvl_phase2_test():
folder + "/buckets.sqlite",
folder + "/rn.sqlite",
folder + "/reaction_report.tex",
bondnet_test_json + "/test.json"
)

worker_payload = WorkerPayload(
Expand All @@ -763,6 +782,9 @@ def euvl_phase2_test():
folder + "/mol_entries.pickle",
folder + "/dispatcher_payload.json",
folder + "/worker_payload.json",
folder + "/dgl_mol_graphs.pickle",
folder + "/grapher_features.pickle",
folder + "/lmdbs/reaction/reaction.lmdb"
]
)

Expand Down Expand Up @@ -869,7 +891,7 @@ def euvl_phase2_test():
tests_passed = False

print("Number of reactions:", network_loader.number_of_reactions)
if network_loader.number_of_reactions == 3912:
if network_loader.number_of_reactions == 3905:
print(bcolors.PASS + "euvl_phase_2_test: correct number of reactions" + bcolors.ENDC)
else:
print(bcolors.FAIL + "euvl_phase_2_test: correct number of reactions" + bcolors.ENDC)
Expand Down Expand Up @@ -1007,8 +1029,8 @@ def euvl_bondnet_test():
# flicho_test,
# co2_test,
# euvl_phase1_test,
# euvl_phase2_test,
euvl_bondnet_test
euvl_phase2_test,
# euvl_bondnet_test
]

for test in tests:
Expand Down

0 comments on commit a5ea643

Please sign in to comment.