Skip to content

Commit

Permalink
hacky fix for multiple DDP launches
Browse files Browse the repository at this point in the history
  • Loading branch information
Gopalji Gaur committed Dec 18, 2024
1 parent c2e4cc8 commit 3af8969
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 11 deletions.
30 changes: 24 additions & 6 deletions neps/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,12 +521,30 @@ def _launch_ddp_runtime(
# hack to get around that.
prev_trial = None
while True:
current_trial = neps_state.get_current_evaluating_trial()
if current_trial is not None and (
prev_trial is None or current_trial.id != prev_trial.id # type: ignore[unreachable]
):
evaluation_fn(**current_trial.config)
prev_trial = current_trial
current_eval_trials = neps_state.get_current_evaluating_trials()
# Only evaluate a new trial if its evaluating worker id matches the worker id
# of the previously evaluated trial.

if len(current_eval_trials) > 0:
current_trial = None
if prev_trial is None:
# TODO: This is wrong. We evaluate the first trial in the list.
# Instead, we need to find and evaluate the trial that is being
# evaluated by the parent process.
# Currently this only works if the DDP trainings are launched after the
# evaluation of some trials has begun.
current_trial = current_eval_trials[0]
else:
for trial in current_eval_trials: # type: ignore[unreachable]
if (
trial.metadata.evaluating_worker_id
== prev_trial.metadata.evaluating_worker_id
) and (trial.id != prev_trial.id):
current_trial = trial
break
if current_trial:
evaluation_fn(**current_trial.config)
prev_trial = current_trial


# TODO: This should be done directly in `api.run` at some point to make it clearer at an
Expand Down
11 changes: 6 additions & 5 deletions neps/state/neps_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,12 @@ def get_next_pending_trial(self, n: int | None = None) -> Trial | list[Trial] |
return take(n, _pending_itr)
return next(_pending_itr, None)

def get_current_evaluating_trial(self) -> Trial | None:
    """Return the first trial that is currently being evaluated, if any.

    Returns:
        The result of ``synced()`` on the first shared trial yielded by
        ``self._trials.evaluating()``, or ``None`` when nothing is yielded.
    """
    evaluating_entries = iter(self._trials.evaluating())
    try:
        # Each entry is a pair; only the second element (the shared trial)
        # is used here.
        _, first_shared_trial = next(evaluating_entries)
    except StopIteration:
        # Nothing is being evaluated right now.
        return None
    return first_shared_trial.synced()
def get_current_evaluating_trials(self) -> list[Trial]:
    """Return every trial that is currently being evaluated.

    Returns:
        A list holding the result of ``synced()`` for each shared trial
        yielded by ``self._trials.evaluating()``, in iteration order. The
        list is empty when nothing is yielded.
    """
    return [
        shared.synced()
        for _, shared in self._trials.evaluating()
    ]

def all_trial_ids(self) -> set[str]:
"""Get all the trial ids that are known about."""
Expand Down

0 comments on commit 3af8969

Please sign in to comment.