Skip to content

Commit

Permalink
BonDNet
Browse files Browse the repository at this point in the history
Updated the molecule wrapper to use the `create_wrapper_mol_from_atoms_and_bonds` helper. This way, features are correctly stored for later use in the featurizer.
  • Loading branch information
santi921 committed Nov 7, 2023
1 parent 539d552 commit c4f7b0b
Showing 1 changed file with 54 additions and 9 deletions.
63 changes: 54 additions & 9 deletions HiPRGen/species_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
from pymatgen.core.sites import Site
from pymatgen.core.structure import Molecule
from pymatgen.analysis.graphs import MoleculeGraph

from bondnet.model.training_utils import get_grapher
from bondnet.core.molwrapper import MoleculeWrapper
from bondnet.data.transformers import HeteroGraphFeatureStandardScaler

from bondnet.core.molwrapper import create_wrapper_mol_from_atoms_and_bonds
from bondnet.utils import int_atom

"""
Phase 1: species filtering
Expand Down Expand Up @@ -221,34 +223,77 @@ def collapse_isomorphism_group(g):

log_message(str(len(fragment_dict.keys())) + " unique fragments found")


# Make DGL Molecule graphs via BonDNet functions
log_message("creating dgl molecule graphs")
dgl_molecules_dict = {}
dgl_molecules = []
extra_keys = []


# BONDNET EDITS # BONDNET EDITS # BONDNET EDITS # BONDNET EDITS # BONDNET EDITS
for mol in mol_entries:
# print(f"mol: {mol.mol_graph}")
molecule_grapher = get_grapher(extra_keys)

non_metal_bonds = [ (i, j) for i, j, _ in mol.covalent_graph.edges.data()]
molecule_grapher = get_grapher(
features = extra_keys,
allowed_charges=[-2,-1,0,1,2],
global_feats=["charge"],
) # same
non_metal_bonds = [ (i, j) for i, j, _ in mol.covalent_graph.edges.data()] # same

# print(f"non metal bonds: {non_metal_bonds}")
mol_wrapper = MoleculeWrapper(mol_graph = mol.mol_graph, free_energy = None, id = mol.entry_id, non_metal_bonds = non_metal_bonds)
# use create molecule wrapper instead here
#mol_wrapper = MoleculeWrapper(
# mol_graph = mol.mol_graph,
# free_energy = None, id = mol.entry_id,
# non_metal_bonds = non_metal_bonds,
# extra_keys = extra_keys
#)
#print(mol.mol_graph.molecule.sites)
#print(mol.mol_graph.graph)
#print(mol.mol_graph.graph.edges())

species = [i.specie for i in mol.mol_graph.molecule.sites]
coords = [i.coords for i in mol.mol_graph.molecule.sites]

bonds = [
[i[0], i[1]] for i in mol.mol_graph.graph.edges()
]

mol_wrapper = create_wrapper_mol_from_atoms_and_bonds(
species,
coords,
bonds,
charge = mol.charge,
functional_group=None,
identifier=mol.entry_id,
original_atom_ind=None,
original_bond_ind=None,
atom_features=None,
bond_features=None,
global_features={"charge": mol.charge}
)
mol_wrapper.nonmetal_bonds = non_metal_bonds
feature = {'charge': mol.charge}
dgl_molecule_graph = molecule_grapher.build_graph_and_featurize(mol_wrapper, extra_feats_info = feature, dataset_species = elements)
dgl_molecule_graph = molecule_grapher.build_graph_and_featurize(
mol_wrapper,
extra_feats_info = feature,
element_set = elements
)
dgl_molecules.append(dgl_molecule_graph)
for nt in ["global", "atom", "bond"]:
print(f"nt: {nt}")
fts = dgl_molecule_graph.nodes[nt].data["feat"]
print(f"features: {fts}")
dgl_molecules_dict[mol.entry_id] = mol.ind
print(molecule_grapher.feature_name)
grapher_features= {'feature_size':molecule_grapher.feature_size, 'feature_name': molecule_grapher.feature_name}
#mol_wrapper_dict[mol.entry_id] = mol_wrapper

# Normalize DGL molecule graphs
scaler = HeteroGraphFeatureStandardScaler(mean = None, std = None)
normalized_graphs = scaler(dgl_molecules)

# BONDNET EDITS # BONDNET EDITS # BONDNET EDITS # BONDNET EDITS # BONDNET EDITS # BONDNET EDITS

# print(f"mean: {scaler._mean}")
# print(f"std: {scaler._std}")

Expand Down Expand Up @@ -307,4 +352,4 @@ def add_electron_species(
mol_entries.append(electron_entry)
with open(mol_entries_pickle_location, "wb") as f:
pickle.dump(mol_entries, f)
return mol_entries
return mol_entries

0 comments on commit c4f7b0b

Please sign in to comment.