Skip to content

Commit

Permalink
fix duplicate logging
Browse files Browse the repository at this point in the history
  • Loading branch information
valhassan committed Nov 18, 2024
1 parent 5948537 commit 5180254
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions geo_deep_learning/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,8 @@ def after_fit(self):
)

best_model = self.model.__class__.load_from_checkpoint(best_model_path)
test_results = test_trainer.test(model=best_model,
dataloaders=self.datamodule.test_dataloader())
test_metrics = {}
for metric_name, metric_value in test_results[0].items():
test_metrics[f"test_{metric_name}"] = metric_value
self.trainer.logger.log_metrics(test_metrics)
test_trainer.test(model=best_model,
dataloaders=self.datamodule.test_dataloader())
self.trainer.logger.log_hyperparams({"best_model_path": best_model_path})
print("Test metrics logged successfully to all loggers.")
self.trainer.strategy.barrier()
Expand Down

0 comments on commit 5180254

Please sign in to comment.