From 87197f859e2dc1faf4698b9b291f663b61cd54e2 Mon Sep 17 00:00:00 2001 From: kt Date: Wed, 13 Mar 2024 20:14:31 -0400 Subject: [PATCH 1/4] prevent _remove_node_features from removing u_ref_relative --- espfit/utils/graphs.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/espfit/utils/graphs.py b/espfit/utils/graphs.py index 9673724..bca5582 100644 --- a/espfit/utils/graphs.py +++ b/espfit/utils/graphs.py @@ -525,7 +525,8 @@ def reshape_conformation_size(self, n_confs=50, include_min_energy_conf=False): graphs into heterogenous graphs with the same number of conformations. This allows shuffling and mini-batching per graph (molecule). - Only g.nodes['g'].data['u_ref'], g.nodes['g'].data['u_ref_relative'], and g.nodes['n1'].data['xyz'] will be updated. + Only g.nodes['g'].data['u_ref'], g.nodes['g'].data['u_ref_relative'], g.nodes['n1'].data['u_ref_prime'], + and g.nodes['n1'].data['xyz'] will be updated. Note that this was also intended to augment datasets with fewer molecular diversity but with more conformers from RNA nucleosides and trinucleotides. 
@@ -572,6 +573,7 @@ def reshape_conformation_size(self, n_confs=50, include_min_energy_conf=False): _g = copy.deepcopy(g) _g.nodes["g"].data["u_ref"] = torch.cat((_g.nodes['g'].data['u_ref'], _g.nodes['g'].data['u_ref'][:, index_random]), dim=-1) + _g.nodes["g"].data["u_ref_relative"] = torch.cat((_g.nodes['g'].data['u_ref_relative'], _g.nodes['g'].data['u_ref_relative'][:, index_random]), dim=-1) _g.nodes["n1"].data["xyz"] = torch.cat((_g.nodes['n1'].data['xyz'], _g.nodes['n1'].data['xyz'][:, index_random, :]), dim=1) _g.nodes['n1'].data['u_ref_prime'] = torch.cat((_g.nodes['n1'].data['u_ref_prime'], _g.nodes['n1'].data['u_ref_prime'][:, index_random, :]), dim=1) new_graphs.append(_g) @@ -603,6 +605,7 @@ def reshape_conformation_size(self, n_confs=50, include_min_energy_conf=False): _logger.debug(f"Iteration {j}: Randomly select {len(index_random)} conformers") _g.nodes["g"].data["u_ref"] = torch.cat((_g.nodes['g'].data['u_ref'][:, index], _g.nodes['g'].data['u_ref'][:, index_random]), dim=-1) + _g.nodes["g"].data["u_ref_relative"] = torch.cat((_g.nodes['g'].data['u_ref_relative'][:, index], _g.nodes['g'].data['u_ref_relative'][:, index_random]), dim=-1) _g.nodes["n1"].data["xyz"] = torch.cat((_g.nodes['n1'].data['xyz'][:, index, :], _g.nodes['n1'].data['xyz'][:, index_random, :]), dim=1) _g.nodes["n1"].data["u_ref_prime"] = torch.cat((_g.nodes['n1'].data['u_ref_prime'][:, index, :], _g.nodes['n1'].data['u_ref_prime'][:, index_random, :]), dim=1) else: @@ -617,6 +620,7 @@ def reshape_conformation_size(self, n_confs=50, include_min_energy_conf=False): _logger.debug(f"Iteration {j}: Extract indice from {idx1} to {idx2}") _g.nodes["g"].data["u_ref"] = _g.nodes['g'].data['u_ref'][:, index] + _g.nodes["g"].data["u_ref_relative"] = _g.nodes['g'].data['u_ref_relative'][:, index] _g.nodes["n1"].data["xyz"] = _g.nodes['n1'].data['xyz'][:, index, :] _g.nodes["n1"].data["u_ref_prime"] = _g.nodes['n1'].data['u_ref_prime'][:, index, :] @@ -640,7 +644,7 @@ def 
_remove_node_features(self): for g in self.graphs: _g = copy.deepcopy(g) for key in g.nodes['g'].data.keys(): - if key.startswith('u_') and key != 'u_ref': + if key.startswith('u_') and key != 'u_ref' and key != 'u_ref_relative': _g.nodes['g'].data.pop(key) for key in g.nodes['n1'].data.keys(): if key.startswith('u_') and key != 'u_ref_prime': From 21ce7273126e7a4bce13cd5a917728dbc15c1091 Mon Sep 17 00:00:00 2001 From: kt Date: Wed, 13 Mar 2024 20:41:15 -0400 Subject: [PATCH 2/4] minor fix --- espfit/app/train.py | 4 ++++ espfit/tests/test_app_train.py | 4 ++-- espfit/tests/test_app_train_sampler.py | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/espfit/app/train.py b/espfit/app/train.py index e4cbfdb..109721d 100644 --- a/espfit/app/train.py +++ b/espfit/app/train.py @@ -329,6 +329,9 @@ def output_directory_path(self, value): def report_loss(self, epoch, loss_dict): """Report loss. + This method reports the loss at a given epoch to a log file. + Each loss component is multiplied by 100 for better readability. 
+ Parameters ---------- loss_dict : dict @@ -342,6 +345,7 @@ def report_loss(self, epoch, loss_dict): log_file_path = os.path.join(self.output_directory_path, 'reporter.log') df_new = pd.DataFrame.from_dict(loss_dict, orient='index').T + df_new.mul(100) # Multiple each loss component by 100 df_new.insert(0, 'epoch', epoch) if os.path.exists(log_file_path): diff --git a/espfit/tests/test_app_train.py b/espfit/tests/test_app_train.py index 1f671c6..81e79c9 100644 --- a/espfit/tests/test_app_train.py +++ b/espfit/tests/test_app_train.py @@ -45,9 +45,9 @@ def test_load_dataset(tmpdir): # Prepare input dataset ready for training temporary_directory = tmpdir.mkdir('misc') ds.drop_duplicates(isomeric=False, keep=True, save_merged_dataset=True, dataset_name='misc', output_directory_path=str(temporary_directory)) - ds.reshape_conformation_size(n_confs=50) ds.compute_relative_energy() - + ds.reshape_conformation_size(n_confs=50) + return ds diff --git a/espfit/tests/test_app_train_sampler.py b/espfit/tests/test_app_train_sampler.py index 1292f4c..1f71247 100644 --- a/espfit/tests/test_app_train_sampler.py +++ b/espfit/tests/test_app_train_sampler.py @@ -46,8 +46,8 @@ def test_load_dataset(tmpdir): # Prepare input dataset ready for training temporary_directory = tmpdir.mkdir('misc') ds.drop_duplicates(save_merged_dataset=True, dataset_name='misc', output_directory_path=str(temporary_directory)) - ds.reshape_conformation_size(n_confs=50) ds.compute_relative_energy() + ds.reshape_conformation_size(n_confs=50) return ds From 790df53bc32b029b3564343a1f8f4ea1cc02b508 Mon Sep 17 00:00:00 2001 From: kt Date: Wed, 13 Mar 2024 20:52:45 -0400 Subject: [PATCH 3/4] add compute_relative before reshape_conformation_size --- espfit/tests/test_utils_graphs.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/espfit/tests/test_utils_graphs.py b/espfit/tests/test_utils_graphs.py index 0446290..868b338 100644 --- a/espfit/tests/test_utils_graphs.py +++ b/espfit/tests/test_utils_graphs.py @@ 
-245,6 +245,7 @@ def test_reshape_conformation_size(mydata_gen2_torsion_sm): """ # Test 1) reshape all dgl graphs to have 30 conformations ds = mydata_gen2_torsion_sm + ds.compute_relative_energy() ds.reshape_conformation_size(n_confs=30) nconfs = [g.nodes['g'].data['u_ref'].shape[1] for g in ds] assert nconfs == [30, 30, 30, 30, 30, 30, 30, 30], 'All molecules should have 30 conformers' @@ -252,6 +253,7 @@ def test_reshape_conformation_size(mydata_gen2_torsion_sm): # Test 2) reshape all dgl graphs to have 30 conformations mydata = files('espfit').joinpath(paths[0]) # PosixPath ds = CustomGraphDataset.load(str(mydata)) + ds.compute_relative_energy() ds.reshape_conformation_size(n_confs=20) nconfs = [g.nodes['g'].data['u_ref'].shape[1] for g in ds] assert nconfs == [20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20], 'All molecules should have 20 conformers' From d6fea040775cff4b9a0a90456479b3094031ff18 Mon Sep 17 00:00:00 2001 From: kt Date: Wed, 13 Mar 2024 20:55:32 -0400 Subject: [PATCH 4/4] multiply loss by 100 when exporting log to improve readability --- espfit/app/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/espfit/app/train.py b/espfit/app/train.py index 109721d..8ee0f6d 100644 --- a/espfit/app/train.py +++ b/espfit/app/train.py @@ -345,7 +345,7 @@ def report_loss(self, epoch, loss_dict): log_file_path = os.path.join(self.output_directory_path, 'reporter.log') df_new = pd.DataFrame.from_dict(loss_dict, orient='index').T - df_new.mul(100) # Multiple each loss component by 100 + df_new = df_new.mul(100) # Multiple each loss component by 100 df_new.insert(0, 'epoch', epoch) if os.path.exists(log_file_path):