Skip to content

Commit

Permalink
Merge pull request #3 from kntkb/fix/reshape_conformation_size
Browse files Browse the repository at this point in the history
fix: prevent _remove_node_features from removing u_ref_relative
  • Loading branch information
kntkb authored Mar 14, 2024
2 parents 7e8b7a9 + d6fea04 commit 205df00
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 5 deletions.
4 changes: 4 additions & 0 deletions espfit/app/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,9 @@ def output_directory_path(self, value):
def report_loss(self, epoch, loss_dict):
"""Report loss.
This method reports the loss at a given epoch to a log file.
Each loss component is multiplied by 100 for better readability.
Parameters
----------
loss_dict : dict
Expand All @@ -342,6 +345,7 @@ def report_loss(self, epoch, loss_dict):

log_file_path = os.path.join(self.output_directory_path, 'reporter.log')
df_new = pd.DataFrame.from_dict(loss_dict, orient='index').T
df_new = df_new.mul(100) # Multiply each loss component by 100
df_new.insert(0, 'epoch', epoch)

if os.path.exists(log_file_path):
Expand Down
4 changes: 2 additions & 2 deletions espfit/tests/test_app_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ def test_load_dataset(tmpdir):
# Prepare input dataset ready for training
temporary_directory = tmpdir.mkdir('misc')
ds.drop_duplicates(isomeric=False, keep=True, save_merged_dataset=True, dataset_name='misc', output_directory_path=str(temporary_directory))
ds.reshape_conformation_size(n_confs=50)
ds.compute_relative_energy()

ds.reshape_conformation_size(n_confs=50)

return ds


Expand Down
2 changes: 1 addition & 1 deletion espfit/tests/test_app_train_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def test_load_dataset(tmpdir):
# Prepare input dataset ready for training
temporary_directory = tmpdir.mkdir('misc')
ds.drop_duplicates(save_merged_dataset=True, dataset_name='misc', output_directory_path=str(temporary_directory))
ds.reshape_conformation_size(n_confs=50)
ds.compute_relative_energy()
ds.reshape_conformation_size(n_confs=50)

return ds

Expand Down
2 changes: 2 additions & 0 deletions espfit/tests/test_utils_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,13 +245,15 @@ def test_reshape_conformation_size(mydata_gen2_torsion_sm):
"""
# Test 1) reshape all dgl graphs to have 30 conformations
ds = mydata_gen2_torsion_sm
ds.compute_relative_energy()
ds.reshape_conformation_size(n_confs=30)
nconfs = [g.nodes['g'].data['u_ref'].shape[1] for g in ds]
assert nconfs == [30, 30, 30, 30, 30, 30, 30, 30], 'All molecules should have 30 conformers'
del ds, nconfs
# Test 2) reshape all dgl graphs to have 30 conformations
mydata = files('espfit').joinpath(paths[0]) # PosixPath
ds = CustomGraphDataset.load(str(mydata))
ds.compute_relative_energy()
ds.reshape_conformation_size(n_confs=20)
nconfs = [g.nodes['g'].data['u_ref'].shape[1] for g in ds]
assert nconfs == [20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20], 'All molecules should have 20 conformers'
Expand Down
8 changes: 6 additions & 2 deletions espfit/utils/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,8 @@ def reshape_conformation_size(self, n_confs=50, include_min_energy_conf=False):
graphs into heterogenous graphs with the same number of conformations. This allows shuffling and mini-batching per
graph (molecule).
Only g.nodes['g'].data['u_ref'], g.nodes['g'].data['u_ref_relative'], and g.nodes['n1'].data['xyz'] will be updated.
Only g.nodes['g'].data['u_ref'], g.nodes['g'].data['u_ref_relative'], g.nodes['n1'].data['u_ref_prime'],
and g.nodes['n1'].data['xyz'] will be updated.
Note that this was also intended to augment datasets with fewer molecular diversity but with more conformers from
RNA nucleosides and trinucleotides.
Expand Down Expand Up @@ -572,6 +573,7 @@ def reshape_conformation_size(self, n_confs=50, include_min_energy_conf=False):

_g = copy.deepcopy(g)
_g.nodes["g"].data["u_ref"] = torch.cat((_g.nodes['g'].data['u_ref'], _g.nodes['g'].data['u_ref'][:, index_random]), dim=-1)
_g.nodes["g"].data["u_ref_relative"] = torch.cat((_g.nodes['g'].data['u_ref_relative'], _g.nodes['g'].data['u_ref_relative'][:, index_random]), dim=-1)
_g.nodes["n1"].data["xyz"] = torch.cat((_g.nodes['n1'].data['xyz'], _g.nodes['n1'].data['xyz'][:, index_random, :]), dim=1)
_g.nodes['n1'].data['u_ref_prime'] = torch.cat((_g.nodes['n1'].data['u_ref_prime'], _g.nodes['n1'].data['u_ref_prime'][:, index_random, :]), dim=1)
new_graphs.append(_g)
Expand Down Expand Up @@ -603,6 +605,7 @@ def reshape_conformation_size(self, n_confs=50, include_min_energy_conf=False):
_logger.debug(f"Iteration {j}: Randomly select {len(index_random)} conformers")

_g.nodes["g"].data["u_ref"] = torch.cat((_g.nodes['g'].data['u_ref'][:, index], _g.nodes['g'].data['u_ref'][:, index_random]), dim=-1)
_g.nodes["g"].data["u_ref_relative"] = torch.cat((_g.nodes['g'].data['u_ref_relative'][:, index], _g.nodes['g'].data['u_ref_relative'][:, index_random]), dim=-1)
_g.nodes["n1"].data["xyz"] = torch.cat((_g.nodes['n1'].data['xyz'][:, index, :], _g.nodes['n1'].data['xyz'][:, index_random, :]), dim=1)
_g.nodes["n1"].data["u_ref_prime"] = torch.cat((_g.nodes['n1'].data['u_ref_prime'][:, index, :], _g.nodes['n1'].data['u_ref_prime'][:, index_random, :]), dim=1)
else:
Expand All @@ -617,6 +620,7 @@ def reshape_conformation_size(self, n_confs=50, include_min_energy_conf=False):
_logger.debug(f"Iteration {j}: Extract indice from {idx1} to {idx2}")

_g.nodes["g"].data["u_ref"] = _g.nodes['g'].data['u_ref'][:, index]
_g.nodes["g"].data["u_ref_relative"] = _g.nodes['g'].data['u_ref_relative'][:, index]
_g.nodes["n1"].data["xyz"] = _g.nodes['n1'].data['xyz'][:, index, :]
_g.nodes["n1"].data["u_ref_prime"] = _g.nodes['n1'].data['u_ref_prime'][:, index, :]

Expand All @@ -640,7 +644,7 @@ def _remove_node_features(self):
for g in self.graphs:
_g = copy.deepcopy(g)
for key in g.nodes['g'].data.keys():
if key.startswith('u_') and key != 'u_ref':
if key.startswith('u_') and key != 'u_ref' and key != 'u_ref_relative':
_g.nodes['g'].data.pop(key)
for key in g.nodes['n1'].data.keys():
if key.startswith('u_') and key != 'u_ref_prime':
Expand Down

0 comments on commit 205df00

Please sign in to comment.