Skip to content

Commit

Permalink
minor fix
Browse files Browse the repository at this point in the history
  • Loading branch information
kntkb committed Mar 14, 2024
1 parent 87197f8 commit 21ce727
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 3 deletions.
4 changes: 4 additions & 0 deletions espfit/app/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,9 @@ def output_directory_path(self, value):
def report_loss(self, epoch, loss_dict):
"""Report loss.
This method reports the loss at a given epoch to a log file.
Each loss component is multiplied by 100 for better readability.
Parameters
----------
loss_dict : dict
Expand All @@ -342,6 +345,7 @@ def report_loss(self, epoch, loss_dict):

log_file_path = os.path.join(self.output_directory_path, 'reporter.log')
df_new = pd.DataFrame.from_dict(loss_dict, orient='index').T
df_new.mul(100) # Multiple each loss component by 100
df_new.insert(0, 'epoch', epoch)

if os.path.exists(log_file_path):
Expand Down
4 changes: 2 additions & 2 deletions espfit/tests/test_app_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ def test_load_dataset(tmpdir):
# Prepare input dataset ready for training
temporary_directory = tmpdir.mkdir('misc')
ds.drop_duplicates(isomeric=False, keep=True, save_merged_dataset=True, dataset_name='misc', output_directory_path=str(temporary_directory))
ds.reshape_conformation_size(n_confs=50)
ds.compute_relative_energy()

ds.reshape_conformation_size(n_confs=50)

return ds


Expand Down
2 changes: 1 addition & 1 deletion espfit/tests/test_app_train_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def test_load_dataset(tmpdir):
# Prepare input dataset ready for training
temporary_directory = tmpdir.mkdir('misc')
ds.drop_duplicates(save_merged_dataset=True, dataset_name='misc', output_directory_path=str(temporary_directory))
ds.reshape_conformation_size(n_confs=50)
ds.compute_relative_energy()
ds.reshape_conformation_size(n_confs=50)

return ds

Expand Down

0 comments on commit 21ce727

Please sign in to comment.