Skip to content

Commit

Permalink
Refactor train.py to log test metrics using logger
Browse files Browse the repository at this point in the history
  • Loading branch information
valhassan committed Sep 27, 2024
1 parent deca95a commit b54ff80
Showing 1 changed file with 15 additions and 33 deletions.
48 changes: 15 additions & 33 deletions geo_deep_learning/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ def after_fit(self):
strategy="auto")
best_model = self.model.__class__.load_from_checkpoint(best_model_path)
test_results = test_trainer.test(model=best_model, dataloaders=self.datamodule.test_dataloader())
self.log_test_metrics_to_mlflow(test_results)
# self.log_test_metrics(test_results)
self.log_test_metrics(test_results)
self.trainer.strategy.barrier()

def print_dataset_sizes(self):
Expand All @@ -33,39 +32,22 @@ def print_dataset_sizes(self):
print(f"Number of validation samples: {val_size}")
print(f"Number of test samples: {test_size}")

def log_test_metrics_to_mlflow(self, test_results):
# Get the MLflow logger from the trainer
mlf_logger = next((logger for logger in self.trainer.loggers if isinstance(logger, MLFlowLogger)), None)

if mlf_logger is not None:
# Log each metric from the test results
def log_test_metrics(self, test_results):
if not self.trainer.logger:
print("No logger found. Test metrics will not be logged.")
return
if isinstance(self.trainer.logger, Logger):
for metric_name, metric_value in test_results[0].items():
mlf_logger.experiment.log_metric(
run_id=mlf_logger.run_id,
key=f"test_{metric_name}",
value=metric_value
)
print("Test metrics logged to MLflow successfully.")
self.trainer.logger.log_metrics({f"test_{metric_name}": metric_value})
print("Test metrics logged successfully.")
elif isinstance(self.trainer.logger, list):
for logger in self.trainer.logger:
if isinstance(logger, Logger):
for metric_name, metric_value in test_results[0].items():
logger.log_metrics({f"test_{metric_name}": metric_value})
print("Test metrics logged successfully to all loggers.")
else:
print("MLflow logger not found in the trainer's loggers.")

# def log_test_metrics(self, test_results):
# if not self.trainer.logger:
# print("No logger found. Test metrics will not be logged.")
# return

# if isinstance(self.trainer.logger, Logger):
# for metric_name, metric_value in test_results[0].items():
# self.trainer.logger.log_metrics({f"test_{metric_name}": metric_value})
# print("Test metrics logged successfully.")
# elif isinstance(self.trainer.logger, list):
# for logger in self.trainer.logger:
# if isinstance(logger, Logger):
# for metric_name, metric_value in test_results[0].items():
# logger.log_metrics({f"test_{metric_name}": metric_value})
# print("Test metrics logged successfully to all loggers.")
# else:
# print("Unsupported logger type. Test metrics will not be logged.")
print("Unsupported logger type. Test metrics will not be logged.")


def main(args: ArgsType = None) -> None:
Expand Down

0 comments on commit b54ff80

Please sign in to comment.