From c2e4cc89a1958c63695d2bf53a4c3ce09796badb Mon Sep 17 00:00:00 2001 From: Gopalji Gaur Date: Sun, 8 Dec 2024 18:43:26 +0100 Subject: [PATCH 1/3] Added support for PyTorch Lightning in the DDP backend. --- neps/runtime.py | 50 +++++++++++++++++++++++++++++++++++++++- neps/state/filebased.py | 9 ++++++++ neps/state/neps_state.py | 6 +++++ neps/state/protocols.py | 4 ++++ 4 files changed, 68 insertions(+), 1 deletion(-) diff --git a/neps/runtime.py b/neps/runtime.py index 71b72f87..81fc70a5 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -26,7 +26,10 @@ WorkerRaiseError, ) from neps.state._eval import evaluate_trial -from neps.state.filebased import create_or_load_filebased_neps_state +from neps.state.filebased import ( + create_or_load_filebased_neps_state, + load_filebased_neps_state, +) from neps.state.optimizer import BudgetInfo, OptimizationState, OptimizerInfo from neps.state.settings import DefaultReportValues, OnErrorPossibilities, WorkerSettings from neps.state.trial import Trial @@ -43,6 +46,24 @@ def _default_worker_name() -> str: return f"{os.getpid()}-{isoformat}" +def _is_ddp_and_not_rank_zero() -> bool: + import torch.distributed as dist + + # Check for environment variables typically set by DDP + ddp_env_vars = ["WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"] + rank_env_vars = ["RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK"] + + # Check if PyTorch distributed is initialized + if (dist.is_available() and dist.is_initialized()) or all( + var in os.environ for var in ddp_env_vars + ): + for var in rank_env_vars: + rank = os.environ.get(var) + if rank is not None: + return int(rank) != 0 + return False + + N_FAILED_GET_NEXT_PENDING_ATTEMPTS_BEFORE_ERROR = 0 N_FAILED_TO_SET_TRIAL_STATE = 10 @@ -488,6 +509,26 @@ def run(self) -> None: # noqa: C901, PLR0915 ) +def _launch_ddp_runtime( + *, + evaluation_fn: Callable[..., float | Mapping[str, Any]], + optimization_dir: Path, +) -> None: + neps_state = load_filebased_neps_state(directory=optimization_dir) + + # TODO: This is a bit of a hack to get the current trial to evaluate. Sometimes + # the previous trial gets sampled when we don't want it to. This is a bit of a + # hack to get around that. + prev_trial = None + while True: + current_trial = neps_state.get_current_evaluating_trial() + if current_trial is not None and ( + prev_trial is None or current_trial.id != prev_trial.id # type: ignore[unreachable] + ): + evaluation_fn(**current_trial.config) + prev_trial = current_trial + + # TODO: This should be done directly in `api.run` at some point to make it clearer at an # entryy point how the woerer is set up to run if someone reads the entry point code. def _launch_runtime( # noqa: PLR0913 @@ -506,6 +547,13 @@ def _launch_runtime( # noqa: PLR0913 max_evaluations_for_worker: int | None, pre_load_hooks: Iterable[Callable[[BaseOptimizer], BaseOptimizer]] | None, ) -> None: + if _is_ddp_and_not_rank_zero(): + # Do not launch a new worker if we are in a DDP setup and not rank 0 + _launch_ddp_runtime( + evaluation_fn=evaluation_fn, optimization_dir=optimization_dir + ) + return + if overwrite_optimization_dir and optimization_dir.exists(): logger.info( f"Overwriting optimization directory '{optimization_dir}' as" diff --git a/neps/state/filebased.py b/neps/state/filebased.py index cf53c622..ed21e81f 100644 --- a/neps/state/filebased.py +++ b/neps/state/filebased.py @@ -209,6 +209,15 @@ def pending(self) -> Iterable[tuple[str, Synced[Trial, Path]]]: ] return iter((_id, t) for _id, t, _ in sorted(pending, key=lambda x: x[2])) + @override + def evaluating(self) -> Iterable[tuple[str, Synced[Trial, Path]]]: + evaluating = [ + (_id, t, trial.metadata.time_sampled) + for (_id, t) in self.all().items() + if (trial := t.synced()).state == Trial.State.EVALUATING + ] + return iter((_id, t) for _id, t, _ in sorted(evaluating, key=lambda x: x[2])) + @dataclass class ReaderWriterTrial(ReaderWriter[Trial, Path]): diff --git a/neps/state/neps_state.py b/neps/state/neps_state.py index 1ed3f67b..22fda4c6 100644 --- a/neps/state/neps_state.py +++ b/neps/state/neps_state.py @@ -214,6 +214,12 @@ def get_next_pending_trial(self, n: int | None = None) -> Trial | list[Trial] | return take(n, _pending_itr) return next(_pending_itr, None) + def get_current_evaluating_trial(self) -> Trial | None: + """Get the current evaluating trial.""" + for _, shared_trial in self._trials.evaluating(): + return shared_trial.synced() + return None + def all_trial_ids(self) -> set[str]: """Get all the trial ids that are known about.""" return self._trials.all_trial_ids() diff --git a/neps/state/protocols.py b/neps/state/protocols.py index 51bff7d3..e7b31302 100644 --- a/neps/state/protocols.py +++ b/neps/state/protocols.py @@ -138,6 +138,10 @@ def pending(self) -> Iterable[tuple[str, Synced[Trial, K]]]: """ ... + def evaluating(self) -> Iterable[tuple[str, Synced[Trial, K]]]: + """Get all evaluating trials in the repo.""" + ... + @dataclass class VersionedResource(Generic[T, K]): From 3af8969d9eb6a5155fb0b30c7f415d42ebf658c3 Mon Sep 17 00:00:00 2001 From: Gopalji Gaur Date: Wed, 18 Dec 2024 01:52:01 +0100 Subject: [PATCH 2/3] hacky fix for multiple DDP launches --- neps/runtime.py | 30 ++++++++++++++++++++++++------ neps/state/neps_state.py | 11 ++++++----- 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/neps/runtime.py b/neps/runtime.py index 81fc70a5..10823e01 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -521,12 +521,30 @@ def _launch_ddp_runtime( # hack to get around that. prev_trial = None while True: - current_trial = neps_state.get_current_evaluating_trial() - if current_trial is not None and ( - prev_trial is None or current_trial.id != prev_trial.id # type: ignore[unreachable] - ): - evaluation_fn(**current_trial.config) - prev_trial = current_trial + current_eval_trials = neps_state.get_current_evaluating_trials() + # If the worker id on previous trial is the same as the current one, only then + # evaluate it. + + if len(current_eval_trials) > 0: + current_trial = None + if prev_trial is None: + # TODO: This is wrong. we evaluate the first trial in the list + # Instead, we need to check and evaluate the trial that is being + # evaluated by the parent process. + # Currently this only works if the DDP trainings are launched after some + # trials evaluation has begun. + current_trial = current_eval_trials[0] + else: + for trial in current_eval_trials: # type: ignore[unreachable] + if ( + trial.metadata.evaluating_worker_id + == prev_trial.metadata.evaluating_worker_id + ) and (trial.id != prev_trial.id): + current_trial = trial + break + if current_trial: + evaluation_fn(**current_trial.config) + prev_trial = current_trial # TODO: This should be done directly in `api.run` at some point to make it clearer at an diff --git a/neps/state/neps_state.py b/neps/state/neps_state.py index 22fda4c6..4bb9b510 100644 --- a/neps/state/neps_state.py +++ b/neps/state/neps_state.py @@ -214,11 +214,12 @@ def get_next_pending_trial(self, n: int | None = None) -> Trial | list[Trial] | return take(n, _pending_itr) return next(_pending_itr, None) - def get_current_evaluating_trial(self) -> Trial | None: - """Get the current evaluating trial.""" - for _, shared_trial in self._trials.evaluating(): - return shared_trial.synced() - return None + def get_current_evaluating_trials(self) -> list[Trial]: + """Get all the current evaluating trials.""" + _eval_itr = ( + shared_trial.synced() for _, shared_trial in self._trials.evaluating() + ) + return list(_eval_itr) def all_trial_ids(self) -> set[str]: """Get all the trial ids that are known about.""" From 7e7810cbf0575f136f8258c2f8d07ced81cb2f58 Mon Sep 17 00:00:00 2001 From: Gopalji Gaur Date: Sun, 22 Dec 2024 23:27:23 +0100 Subject: [PATCH 3/3] use env var to share config with higher rank workers --- neps/runtime.py | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/neps/runtime.py b/neps/runtime.py index 10823e01..48738e5f 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -46,6 +46,9 @@ def _default_worker_name() -> str: return f"{os.getpid()}-{isoformat}" +_DDP_ENV_VAR_NAME = "NEPS_DDP_TRIAL_ID" + + def _is_ddp_and_not_rank_zero() -> bool: import torch.distributed as dist @@ -64,6 +67,11 @@ def _is_ddp_and_not_rank_zero() -> bool: return False +def _set_ddp_env_var(trial_id: str) -> None: + """Sets an environment variable with current trial_id in a DDP setup.""" + os.environ[_DDP_ENV_VAR_NAME] = trial_id + + N_FAILED_GET_NEXT_PENDING_ATTEMPTS_BEFORE_ERROR = 0 N_FAILED_TO_SET_TRIAL_STATE = 10 @@ -131,6 +139,7 @@ def _set_global_trial(trial: Trial) -> Iterator[None]: "\n\nThis is most likely a bug and should be reported to NePS!" ) _CURRENTLY_RUNNING_TRIAL_IN_PROCESS = trial + _set_ddp_env_var(trial.id) yield for _key, callback in _TRIAL_END_CALLBACKS.items(): callback(trial) @@ -515,25 +524,25 @@ def _launch_ddp_runtime( optimization_dir: Path, ) -> None: neps_state = load_filebased_neps_state(directory=optimization_dir) - - # TODO: This is a bit of a hack to get the current trial to evaluate. Sometimes - # the previous trial gets sampled when we don't want it to. This is a bit of a - # hack to get around that. prev_trial = None while True: current_eval_trials = neps_state.get_current_evaluating_trials() # If the worker id on previous trial is the same as the current one, only then # evaluate it. - if len(current_eval_trials) > 0: current_trial = None if prev_trial is None: - # TODO: This is wrong. we evaluate the first trial in the list - # Instead, we need to check and evaluate the trial that is being - # evaluated by the parent process. - # Currently this only works if the DDP trainings are launched after some - # trials evaluation has begun. - current_trial = current_eval_trials[0] + # In the beginning, we simply read the current trial from the + # environment variable + if _DDP_ENV_VAR_NAME in os.environ: + current_id = os.getenv(_DDP_ENV_VAR_NAME) + if current_id is None: + raise RuntimeError( + "In a pytorch-lightning DDP setup, the environment variable" + f" '{_DDP_ENV_VAR_NAME}' was not set. This is probably a bug in" + " NePS and should be reported." + ) + current_trial = neps_state.get_trial_by_id(current_id) else: for trial in current_eval_trials: # type: ignore[unreachable] if (