From 81f8b9dae9d0d282eda43310e3fcdd565cc2054d Mon Sep 17 00:00:00 2001 From: ElliottKasoar <45317199+ElliottKasoar@users.noreply.github.com> Date: Fri, 13 Dec 2024 11:32:46 +0000 Subject: [PATCH] Refactor setting log and tracker defaults --- janus_core/calculations/base.py | 16 ++++---------- janus_core/helpers/utils.py | 35 +++++++++++++++++++++++++++++++ janus_core/training/preprocess.py | 16 ++++---------- janus_core/training/train.py | 16 ++++---------- 4 files changed, 47 insertions(+), 36 deletions(-) diff --git a/janus_core/calculations/base.py b/janus_core/calculations/base.py index 461cac7d..22b5a9e6 100644 --- a/janus_core/calculations/base.py +++ b/janus_core/calculations/base.py @@ -16,7 +16,7 @@ ) from janus_core.helpers.log import config_logger, config_tracker from janus_core.helpers.struct_io import input_structs -from janus_core.helpers.utils import FileNameMixin, none_to_dict +from janus_core.helpers.utils import FileNameMixin, none_to_dict, set_log_tracker class BaseCalculation(FileNameMixin): @@ -155,17 +155,9 @@ def __init__( if not self.model_path and "model_path" in self.calc_kwargs: raise ValueError("`model_path` must be passed explicitly") - if "filename" in log_kwargs: - attach_logger = True - else: - attach_logger = attach_logger if attach_logger else False - - if not attach_logger: - if track_carbon: - raise ValueError("Carbon tracking requires logging to be enabled") - self.track_carbon = False - else: - self.track_carbon = track_carbon if track_carbon is not None else True + attach_logger, self.track_carbon = set_log_tracker( + attach_logger, log_kwargs, track_carbon + ) # Read structures and/or attach calculators # Note: logger not set up so yet so not passed here diff --git a/janus_core/helpers/utils.py b/janus_core/helpers/utils.py index 32ed6043..b1976c32 100644 --- a/janus_core/helpers/utils.py +++ b/janus_core/helpers/utils.py @@ -517,3 +517,38 @@ def selector_len(slc: SliceLike | list, selectable_length: int) -> int: if stop is None: stop = selectable_length return len(range(start, stop, step)) + + +def set_log_tracker( + attach_logger: bool, log_kwargs: dict, track_carbon: bool +) -> tuple[bool, bool]: + """ + Set attach_logger and track_carbon default values. + + Parameters + ---------- + attach_logger : bool + Whether to attach a logger. + log_kwargs : dict[str, Any] + Keyword arguments to pass to `config_logger`. + track_carbon : bool + Whether to track carbon emissions of calculation. + + Returns + ------- + tuple[bool, bool] + Default values for attach_logger and track_carbon. + """ + if "filename" in log_kwargs: + attach_logger = True + else: + attach_logger = attach_logger if attach_logger else False + + if not attach_logger: + if track_carbon: + raise ValueError("Carbon tracking requires logging to be enabled") + track_carbon = False + else: + track_carbon = track_carbon if track_carbon is not None else True + + return attach_logger, track_carbon diff --git a/janus_core/training/preprocess.py b/janus_core/training/preprocess.py index 392d2b4b..dd02536f 100644 --- a/janus_core/training/preprocess.py +++ b/janus_core/training/preprocess.py @@ -11,7 +11,7 @@ from janus_core.helpers.janus_types import PathLike from janus_core.helpers.log import config_logger, config_tracker -from janus_core.helpers.utils import check_files_exist, none_to_dict +from janus_core.helpers.utils import check_files_exist, none_to_dict, set_log_tracker def preprocess( @@ -53,17 +53,9 @@ def preprocess( options = yaml.safe_load(file) check_files_exist(options, req_file_keys) - if "filename" in log_kwargs: - attach_logger = True - else: - attach_logger = attach_logger if attach_logger else False - - if not attach_logger: - if track_carbon: - raise ValueError("Carbon tracking requires logging to be enabled") - track_carbon = False - else: - track_carbon = track_carbon if track_carbon is not None else True + attach_logger, track_carbon = set_log_tracker( + attach_logger, log_kwargs, track_carbon + ) # Configure logging if attach_logger: diff --git a/janus_core/training/train.py b/janus_core/training/train.py index 70fbdd92..8abc33f7 100644 --- a/janus_core/training/train.py +++ b/janus_core/training/train.py @@ -11,7 +11,7 @@ from janus_core.helpers.janus_types import PathLike from janus_core.helpers.log import config_logger, config_tracker -from janus_core.helpers.utils import check_files_exist, none_to_dict +from janus_core.helpers.utils import check_files_exist, none_to_dict, set_log_tracker def train( @@ -58,17 +58,9 @@ def train( options = yaml.safe_load(file) check_files_exist(options, req_file_keys) - if "filename" in log_kwargs: - attach_logger = True - else: - attach_logger = attach_logger if attach_logger else False - - if not attach_logger: - if track_carbon: - raise ValueError("Carbon tracking requires logging to be enabled") - track_carbon = False - else: - track_carbon = track_carbon if track_carbon is not None else True + attach_logger, track_carbon = set_log_tracker( + attach_logger, log_kwargs, track_carbon + ) # Configure logging if attach_logger: