Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

prevent _remove_node_features from removing u_ref_relative #3

Merged
merged 4 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions espfit/app/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,9 @@ def output_directory_path(self, value):
def report_loss(self, epoch, loss_dict):
"""Report loss.

This method reports the loss at a given epoch to a log file.
Each loss component is multiplied by 100 for better readability.

Parameters
----------
loss_dict : dict
Expand All @@ -342,6 +345,7 @@ def report_loss(self, epoch, loss_dict):

log_file_path = os.path.join(self.output_directory_path, 'reporter.log')
df_new = pd.DataFrame.from_dict(loss_dict, orient='index').T
df_new = df_new.mul(100) # Multiply each loss component by 100
df_new.insert(0, 'epoch', epoch)

if os.path.exists(log_file_path):
Expand Down
4 changes: 2 additions & 2 deletions espfit/tests/test_app_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ def test_load_dataset(tmpdir):
# Prepare input dataset ready for training
temporary_directory = tmpdir.mkdir('misc')
ds.drop_duplicates(isomeric=False, keep=True, save_merged_dataset=True, dataset_name='misc', output_directory_path=str(temporary_directory))
ds.reshape_conformation_size(n_confs=50)
ds.compute_relative_energy()

ds.reshape_conformation_size(n_confs=50)

return ds


Expand Down
2 changes: 1 addition & 1 deletion espfit/tests/test_app_train_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def test_load_dataset(tmpdir):
# Prepare input dataset ready for training
temporary_directory = tmpdir.mkdir('misc')
ds.drop_duplicates(save_merged_dataset=True, dataset_name='misc', output_directory_path=str(temporary_directory))
ds.reshape_conformation_size(n_confs=50)
ds.compute_relative_energy()
ds.reshape_conformation_size(n_confs=50)

return ds

Expand Down
2 changes: 2 additions & 0 deletions espfit/tests/test_utils_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,13 +245,15 @@ def test_reshape_conformation_size(mydata_gen2_torsion_sm):
"""
# Test 1) reshape all dgl graphs to have 30 conformations
ds = mydata_gen2_torsion_sm
ds.compute_relative_energy()
ds.reshape_conformation_size(n_confs=30)
nconfs = [g.nodes['g'].data['u_ref'].shape[1] for g in ds]
assert nconfs == [30, 30, 30, 30, 30, 30, 30, 30], 'All molecules should have 30 conformers'
del ds, nconfs
# Test 2) reshape all dgl graphs to have 30 conformations
mydata = files('espfit').joinpath(paths[0]) # PosixPath
ds = CustomGraphDataset.load(str(mydata))
ds.compute_relative_energy()
ds.reshape_conformation_size(n_confs=20)
nconfs = [g.nodes['g'].data['u_ref'].shape[1] for g in ds]
assert nconfs == [20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20], 'All molecules should have 20 conformers'
Expand Down
8 changes: 6 additions & 2 deletions espfit/utils/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,8 @@ def reshape_conformation_size(self, n_confs=50, include_min_energy_conf=False):
graphs into heterogenous graphs with the same number of conformations. This allows shuffling and mini-batching per
graph (molecule).

Only g.nodes['g'].data['u_ref'], g.nodes['g'].data['u_ref_relative'], and g.nodes['n1'].data['xyz'] will be updated.
Only g.nodes['g'].data['u_ref'], g.nodes['g'].data['u_ref_relative'], g.nodes['n1'].data['u_ref_prime'],
and g.nodes['n1'].data['xyz'] will be updated.

Note that this was also intended to augment datasets with fewer molecular diversity but with more conformers from
RNA nucleosides and trinucleotides.
Expand Down Expand Up @@ -572,6 +573,7 @@ def reshape_conformation_size(self, n_confs=50, include_min_energy_conf=False):

_g = copy.deepcopy(g)
_g.nodes["g"].data["u_ref"] = torch.cat((_g.nodes['g'].data['u_ref'], _g.nodes['g'].data['u_ref'][:, index_random]), dim=-1)
_g.nodes["g"].data["u_ref_relative"] = torch.cat((_g.nodes['g'].data['u_ref_relative'], _g.nodes['g'].data['u_ref_relative'][:, index_random]), dim=-1)
_g.nodes["n1"].data["xyz"] = torch.cat((_g.nodes['n1'].data['xyz'], _g.nodes['n1'].data['xyz'][:, index_random, :]), dim=1)
_g.nodes['n1'].data['u_ref_prime'] = torch.cat((_g.nodes['n1'].data['u_ref_prime'], _g.nodes['n1'].data['u_ref_prime'][:, index_random, :]), dim=1)
new_graphs.append(_g)
Expand Down Expand Up @@ -603,6 +605,7 @@ def reshape_conformation_size(self, n_confs=50, include_min_energy_conf=False):
_logger.debug(f"Iteration {j}: Randomly select {len(index_random)} conformers")

_g.nodes["g"].data["u_ref"] = torch.cat((_g.nodes['g'].data['u_ref'][:, index], _g.nodes['g'].data['u_ref'][:, index_random]), dim=-1)
_g.nodes["g"].data["u_ref_relative"] = torch.cat((_g.nodes['g'].data['u_ref_relative'][:, index], _g.nodes['g'].data['u_ref_relative'][:, index_random]), dim=-1)
_g.nodes["n1"].data["xyz"] = torch.cat((_g.nodes['n1'].data['xyz'][:, index, :], _g.nodes['n1'].data['xyz'][:, index_random, :]), dim=1)
_g.nodes["n1"].data["u_ref_prime"] = torch.cat((_g.nodes['n1'].data['u_ref_prime'][:, index, :], _g.nodes['n1'].data['u_ref_prime'][:, index_random, :]), dim=1)
else:
Expand All @@ -617,6 +620,7 @@ def reshape_conformation_size(self, n_confs=50, include_min_energy_conf=False):
_logger.debug(f"Iteration {j}: Extract indice from {idx1} to {idx2}")

_g.nodes["g"].data["u_ref"] = _g.nodes['g'].data['u_ref'][:, index]
_g.nodes["g"].data["u_ref_relative"] = _g.nodes['g'].data['u_ref_relative'][:, index]
_g.nodes["n1"].data["xyz"] = _g.nodes['n1'].data['xyz'][:, index, :]
_g.nodes["n1"].data["u_ref_prime"] = _g.nodes['n1'].data['u_ref_prime'][:, index, :]

Expand All @@ -640,7 +644,7 @@ def _remove_node_features(self):
for g in self.graphs:
_g = copy.deepcopy(g)
for key in g.nodes['g'].data.keys():
if key.startswith('u_') and key != 'u_ref':
if key.startswith('u_') and key != 'u_ref' and key != 'u_ref_relative':
_g.nodes['g'].data.pop(key)
for key in g.nodes['n1'].data.keys():
if key.startswith('u_') and key != 'u_ref_prime':
Expand Down
Loading