Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Debugging and fixing production #18

Merged
merged 1 commit into from
Dec 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion HiPRGen/mol_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,12 @@ def find_fragment_atom_mappings(fragment_1, fragment_2, return_one=True):
hot_f2_indices.append(fragment_2.hot_atoms[hot_nbh_hash])
hot_f1_indices.append(fragment_1.hot_atoms[hot_nbh_hash])

#print("hot_f1_indices:", hot_f1_indices)
#print("hot_f2_indices:", hot_f2_indices)

for left_index in fragment_1.compressed_graph.nodes():

#print("left_index:", left_index)

neighborhood_hash = fragment_1.neighborhood_hashes[left_index]
if neighborhood_hash in match_hot_atoms and left_index in hot_f1_indices:
neighborhood_hash = neighborhood_hash + "hot"
Expand All @@ -57,27 +60,40 @@ def find_fragment_atom_mappings(fragment_1, fragment_2, return_one=True):

groups_by_hash[neighborhood_hash][0].append(left_index)

#print("groups_by_hash:", groups_by_hash)

for right_index in fragment_2.compressed_graph.nodes():
#print("right_index:", right_index)

neighborhood_hash = fragment_2.neighborhood_hashes[right_index]
if neighborhood_hash in match_hot_atoms and right_index in hot_f2_indices:
neighborhood_hash = neighborhood_hash + "hot"
if neighborhood_hash not in groups_by_hash:
#print("NEW HASH HOW IS THIS POSSIBLE?!!!??!?!")
#print(huh)
groups_by_hash[neighborhood_hash] = ([],[])

groups_by_hash[neighborhood_hash][1].append(right_index)

#print("groups_by_hash:", groups_by_hash)

groups = list(groups_by_hash.values())
#print("groups:", groups)

product_sym_iterator = product(*[sym_iterator(len(p[0])) for p in groups])

#print("product_sym_iterator:", product_sym_iterator)
mappings = []

for product_perm in product_sym_iterator:
#print("product_perm:", product_perm)
mapping = {}
for perm, vals in zip(product_perm, groups):
#print("perm:", perm)
#print("vals:", vals)
for i, j in enumerate(perm):
#print("i:", i)
#print("j:", j)
mapping[vals[0][i]] = vals[1][j]

isomorphism = True
Expand Down
52 changes: 26 additions & 26 deletions HiPRGen/reaction_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,20 +60,20 @@ class Terminal(Enum): KEEP = 1 DISCARD = -1

hydrogen_graph = nx.MultiGraph()
hydrogen_graph.add_node(0, specie="H")
hydrogen_hash = weisfeiler_lehman_graph_hash(hydrogen_graph, node_attr="specie")
hydrogen_hash = weisfeiler_lehman_graph_hash(hydrogen_graph, node_attr="specie", iterations=4)

fluorine_graph = nx.MultiGraph()
fluorine_graph.add_node(0, specie="F")
fluorine_hash = weisfeiler_lehman_graph_hash(fluorine_graph, node_attr="specie")
fluorine_hash = weisfeiler_lehman_graph_hash(fluorine_graph, node_attr="specie", iterations=4)

carbon_graph = nx.MultiGraph()
carbon_graph.add_node(0, specie="C")
carbon_hash = weisfeiler_lehman_graph_hash(carbon_graph, node_attr="specie")
carbon_hash = weisfeiler_lehman_graph_hash(carbon_graph, node_attr="specie", iterations=4)

co3_mol = Molecule.from_file("xyz_files/co3.xyz")
co3_mg = MoleculeGraph.with_local_env_strategy(co3_mol, OpenBabelNN())
co3_g = co3_mg.graph.to_undirected()
co3_hash = weisfeiler_lehman_graph_hash(co3_g, node_attr="specie")
co3_hash = weisfeiler_lehman_graph_hash(co3_g, node_attr="specie", iterations=4)


def run_decision_tree(
Expand Down Expand Up @@ -1939,27 +1939,27 @@ def __call__(self, reaction, mol_entries, params):
(dG_above_threshold(0.0, "free_energy", 0.0), Terminal.DISCARD),
(reactants_are_both_anions_or_both_cations(), Terminal.DISCARD),
(reaction_is_charge_transfer(), Terminal.KEEP),
# (reaction_is_covalent_charge_decomposable(), Terminal.DISCARD),
# (reaction_is_coupled_electron_fragment_transfer(), Terminal.DISCARD),
# (star_count_diff_above_threshold(6), Terminal.DISCARD),
# (
# fragment_matching_found(),
# [
# (single_reactant_single_product_not_atom_transfer(), Terminal.DISCARD),
# (single_reactant_double_product_ring_close(), Terminal.DISCARD),
# (reaction_is_hindered(), Terminal.DISCARD),
# (
# reaction_is_covalent_decomposable(),
# [
# (fragments_are_not_2A_B(), Terminal.DISCARD),
# (mapping_with_reaction_center_not_found(), Terminal.DISCARD),
# (reaction_default_true(), Terminal.KEEP),
# ],
# ),
# (mapping_with_reaction_center_not_found(), Terminal.DISCARD),
# (reaction_default_true(), Terminal.KEEP),
# ],
# ),
(reaction_is_covalent_charge_decomposable(), Terminal.DISCARD),
(reaction_is_coupled_electron_fragment_transfer(), Terminal.DISCARD),
(star_count_diff_above_threshold(6), Terminal.DISCARD),
(
fragment_matching_found(),
[
(single_reactant_single_product_not_atom_transfer(), Terminal.DISCARD),
(single_reactant_double_product_ring_close(), Terminal.DISCARD),
(reaction_is_hindered(), Terminal.DISCARD),
(
reaction_is_covalent_decomposable(),
[
(fragments_are_not_2A_B(), Terminal.DISCARD),
(mapping_with_reaction_center_not_found(), Terminal.DISCARD),
(reaction_default_true(), Terminal.KEEP),
],
),
(mapping_with_reaction_center_not_found(), Terminal.DISCARD),
(reaction_default_true(), Terminal.KEEP),
],
),
(reaction_default_true(), Terminal.DISCARD),
]

Expand Down Expand Up @@ -1990,4 +1990,4 @@ def __call__(self, reaction, mol_entries, params):
],
),
(reaction_default_true(), Terminal.DISCARD),
]
]
18 changes: 16 additions & 2 deletions HiPRGen/species_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,14 +213,28 @@ def collapse_isomorphism_group(g):
log_message("mapping fragments")
fragment_dict = {}
for mol in mol_entries:
# print(mol.entry_id)
log_message("")
log_message(mol.entry_id)
f = open("mapping_record", "a")
f.write("Mapping " + mol.entry_id + "\n")
f.close()
for fragment_complex in mol.fragment_data:
log_message("new fragment complex")
log_message("bonds broken:" + str(fragment_complex.bonds_broken))
for ii, fragment in enumerate(fragment_complex.fragment_objects):
log_message(str(ii))
hot_nbh_hashes = list(fragment.hot_atoms.keys())
log_message("hot_nbh_hashes:" + str(hot_nbh_hashes))
assert len(hot_nbh_hashes) == 0 or len(hot_nbh_hashes) == 1 or len(hot_nbh_hashes) == 2
assert fragment.fragment_hash == fragment_complex.fragment_hashes[ii]
if fragment.fragment_hash not in fragment_dict:
log_message("Adding new fragment! Hash:" + fragment.fragment_hash)
fragment_dict[fragment.fragment_hash] = copy.deepcopy(fragment)
else:
log_message("Fragment hash already in dict")
log_message("fragment hash:" + fragment.fragment_hash)
#print("fragment:", fragment)
#print("fragment_dict[fragment.fragment_hash]):", fragment_dict[fragment.fragment_hash])
all_mappings = find_fragment_atom_mappings(
fragment,
fragment_dict[fragment.fragment_hash])
Expand Down Expand Up @@ -407,4 +421,4 @@ def add_electron_species(
def clean(input):
return "".join([i for i in input if not i.isdigit()])
def clean_op(input):
return "".join([i for i in input if i.isdigit()])
return "".join([i for i in input if i.isdigit()])
27 changes: 20 additions & 7 deletions HiPRGen/species_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def __call__(self, mol):
) #covalently bonded to

mol.star_hashes[i] = weisfeiler_lehman_graph_hash( #star_hashes is a dictionary, and this adds an entry to it
neighborhood, node_attr="specie" #with the atom index, i, as the key and a graph_hash (string)
neighborhood, node_attr="specie", iterations=6 #with the atom index, i, as the key and a graph_hash (string)
) #as the value

return False
Expand Down Expand Up @@ -151,7 +151,9 @@ def __call__(self, mol):

neighborhood_hash = weisfeiler_lehman_graph_hash(
neighborhood,
node_attr='specie')
node_attr='specie',
iterations=6,
)

hash_list.append(neighborhood_hash)

Expand Down Expand Up @@ -196,7 +198,7 @@ def __call__(self, mol):
subgraph = h.subgraph(c) #generates a subgraph from one set of nodes (this is a fragment graph)

fragment_hash = weisfeiler_lehman_graph_hash( #saves the hash of this graph
subgraph, node_attr="specie"
subgraph, node_attr="specie", iterations=6
)

tmp["c"] = copy.deepcopy(c)
Expand Down Expand Up @@ -237,7 +239,9 @@ def __call__(self, mol):

neighborhood_hash = weisfeiler_lehman_graph_hash(
neighborhood,
node_attr='specie')
node_attr='specie',
iterations=6,
)

hash_list.append(neighborhood_hash)

Expand Down Expand Up @@ -342,7 +346,7 @@ def __call__(self, mol):
subgraph = h.subgraph(c)

fragment_hash = weisfeiler_lehman_graph_hash(
subgraph, node_attr="specie"
subgraph, node_attr="specie", iterations=6
)

fragments.append(fragment_hash)
Expand Down Expand Up @@ -389,6 +393,14 @@ def __call__(self, mol):
return mol.formula == "H1 O1" and mol.charge == 1


class formula_filter(MSONable):
def __init__(self, formula):
self.formula = formula

def __call__(self, mol):
return mol.formula == self.formula


class fix_hydrogen_bonding(MSONable):
def __init__(self):
pass
Expand Down Expand Up @@ -514,10 +526,10 @@ def __call__(self, mol):


def compute_graph_hashes(mol):
mol.total_hash = weisfeiler_lehman_graph_hash(mol.graph, node_attr="specie")
mol.total_hash = weisfeiler_lehman_graph_hash(mol.graph, node_attr="specie", iterations=6)

mol.covalent_hash = weisfeiler_lehman_graph_hash(
mol.covalent_graph, node_attr="specie"
mol.covalent_graph, node_attr="specie", iterations=6
)

return False
Expand Down Expand Up @@ -608,6 +620,7 @@ def __call__(self, mol):
(fix_hydrogen_bonding(), Terminal.KEEP),
(h_atom_filter(), Terminal.DISCARD),
(oh_plus_filter(), Terminal.DISCARD),
(formula_filter("C36 H28 S2"), Terminal.DISCARD),
(compute_graph_hashes, Terminal.KEEP),
(add_star_hashes(), Terminal.KEEP),
(add_unbroken_fragment(neighborhood_width=width), Terminal.KEEP),
Expand Down
30 changes: 26 additions & 4 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,8 +710,23 @@ def euvl_phase2_test():
folder = "./scratch/euvl_phase2_test"
subprocess.run(["mkdir", folder])

bondnet_test_json = "./scratch/euvl_phase2_test/reaction_networks_graphs"
lmdbs_path_mol = "./scratch/euvl_phase2_test/lmdbs/mol"
lmdbs_path_reaction = "./scratch/euvl_phase2_test/lmdbs/reaction"
subprocess.run(["mkdir", bondnet_test_json])
subprocess.run(["mkdir", "-p",lmdbs_path_mol])
subprocess.run(["mkdir", "-p",lmdbs_path_reaction])

mol_json = "./data/euvl_test_set.json"
#mol_json = "./data/problem_entries.json"
database_entries = loadfn(mol_json)
#database_entries = [database_entries[2], database_entries[4]]
#database_entries = [database_entries[3], database_entries[4]]

problem_entries = loadfn("./data/big_entries.json")
for entry in problem_entries:
database_entries.append(entry)


species_decision_tree = euvl_species_decision_tree

Expand All @@ -720,14 +735,17 @@ def euvl_phase2_test():
"electron_free_energy": 0.0,
}

mol_entries = species_filter(
mol_entries, dgl_molecules_dict = species_filter(
database_entries,
mol_entries_pickle_location=folder + "/mol_entries.pickle",
dgl_mol_grphs_pickle_location = folder + "/dgl_mol_graphs.pickle",
grapher_features_pickle_location= folder + "/grapher_features.pickle",
species_report=folder + "/unfiltered_species_report.tex",
species_decision_tree=species_decision_tree,
coordimer_weight=lambda mol: (mol.get_free_energy(params["temperature"])),
species_logging_decision_tree=species_decision_tree,
generate_unfiltered_mol_pictures=False,
mol_lmdb_path = folder + "/lmdbs/mol/mol.lmdb",
)

print(len(mol_entries), "initial mol entries")
Expand All @@ -740,6 +758,7 @@ def euvl_phase2_test():
folder + "/buckets.sqlite",
folder + "/rn.sqlite",
folder + "/reaction_report.tex",
bondnet_test_json + "/test.json"
)

worker_payload = WorkerPayload(
Expand All @@ -763,6 +782,9 @@ def euvl_phase2_test():
folder + "/mol_entries.pickle",
folder + "/dispatcher_payload.json",
folder + "/worker_payload.json",
folder + "/dgl_mol_graphs.pickle",
folder + "/grapher_features.pickle",
folder + "/lmdbs/reaction/reaction.lmdb"
]
)

Expand Down Expand Up @@ -869,7 +891,7 @@ def euvl_phase2_test():
tests_passed = False

print("Number of reactions:", network_loader.number_of_reactions)
if network_loader.number_of_reactions == 3912:
if network_loader.number_of_reactions == 3905:
print(bcolors.PASS + "euvl_phase_2_test: correct number of reactions" + bcolors.ENDC)
else:
print(bcolors.FAIL + "euvl_phase_2_test: correct number of reactions" + bcolors.ENDC)
Expand Down Expand Up @@ -1007,8 +1029,8 @@ def euvl_bondnet_test():
# flicho_test,
# co2_test,
# euvl_phase1_test,
# euvl_phase2_test,
euvl_bondnet_test
euvl_phase2_test,
# euvl_bondnet_test
]

for test in tests:
Expand Down
Loading