diff --git a/geo_deep_learning/train.py b/geo_deep_learning/train.py index 9f899662..3fecf13d 100644 --- a/geo_deep_learning/train.py +++ b/geo_deep_learning/train.py @@ -19,8 +19,7 @@ def after_fit(self): strategy="auto") best_model = self.model.__class__.load_from_checkpoint(best_model_path) test_results = test_trainer.test(model=best_model, dataloaders=self.datamodule.test_dataloader()) - self.log_test_metrics_to_mlflow(test_results) - # self.log_test_metrics(test_results) + self.log_test_metrics(test_results) self.trainer.strategy.barrier() def print_dataset_sizes(self): @@ -33,39 +32,22 @@ def print_dataset_sizes(self): print(f"Number of validation samples: {val_size}") print(f"Number of test samples: {test_size}") - def log_test_metrics_to_mlflow(self, test_results): - # Get the MLflow logger from the trainer - mlf_logger = next((logger for logger in self.trainer.loggers if isinstance(logger, MLFlowLogger)), None) - - if mlf_logger is not None: - # Log each metric from the test results + def log_test_metrics(self, test_results): + if not self.trainer.logger: + print("No logger found. Test metrics will not be logged.") + return + if isinstance(self.trainer.logger, Logger): for metric_name, metric_value in test_results[0].items(): - mlf_logger.experiment.log_metric( - run_id=mlf_logger.run_id, - key=f"test_{metric_name}", - value=metric_value - ) - print("Test metrics logged to MLflow successfully.") + self.trainer.logger.log_metrics({f"test_{metric_name}": metric_value}) + print("Test metrics logged successfully.") + elif isinstance(self.trainer.logger, list): + for logger in self.trainer.logger: + if isinstance(logger, Logger): + for metric_name, metric_value in test_results[0].items(): + logger.log_metrics({f"test_{metric_name}": metric_value}) + print("Test metrics logged successfully to all loggers.") else: - print("MLflow logger not found in the trainer's loggers.") - - # def log_test_metrics(self, test_results): - # if not self.trainer.logger: - # print("No logger found. Test metrics will not be logged.") - # return - - # if isinstance(self.trainer.logger, Logger): - # for metric_name, metric_value in test_results[0].items(): - # self.trainer.logger.log_metrics({f"test_{metric_name}": metric_value}) - # print("Test metrics logged successfully.") - # elif isinstance(self.trainer.logger, list): - # for logger in self.trainer.logger: - # if isinstance(logger, Logger): - # for metric_name, metric_value in test_results[0].items(): - # logger.log_metrics({f"test_{metric_name}": metric_value}) - # print("Test metrics logged successfully to all loggers.") - # else: - # print("Unsupported logger type. Test metrics will not be logged.") + print("Unsupported logger type. Test metrics will not be logged.") def main(args: ArgsType = None) -> None: