From a9611445612803bdac2c3f6775d7eb25be7d000d Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Thu, 28 Nov 2024 15:41:57 +0100 Subject: [PATCH 01/56] fix: Local variable doesn't exist --- neps/runtime.py | 1 + 1 file changed, 1 insertion(+) diff --git a/neps/runtime.py b/neps/runtime.py index 71b72f87..0a095865 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -379,6 +379,7 @@ def run(self) -> None: # noqa: C901, PLR0915 _error_from_evaluation: Exception | None = None _repeated_fail_get_next_trial_count = 0 + n_failed_set_trial_state = 0 while True: # NOTE: We rely on this function to do logging and raising errors if it should should_stop = self._check_if_should_stop( From 248255215724be7872375d52e41d16a680e52ff4 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Thu, 28 Nov 2024 15:48:00 +0100 Subject: [PATCH 02/56] optim: Recompute cost usage at sample, not in report This means that reporting a trial value no longer has to lock the optimizers state, as this will now be computed when sampling a new trial, given we have access to all the trials and the optimizer state at this point. 
--- neps/state/neps_state.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/neps/state/neps_state.py b/neps/state/neps_state.py index 1ed3f67b..a2249ed9 100644 --- a/neps/state/neps_state.py +++ b/neps/state/neps_state.py @@ -102,8 +102,21 @@ def sample_trial( for hook in _sample_hooks: optimizer = hook(optimizer) - # NOTE: We don't want optimizers mutating this before serialization - budget = opt_state.budget.clone() if opt_state.budget is not None else None + # NOTE: Re-work this, as the part's that are recomputed do not need to be serialized + budget = opt_state.budget + if budget is not None: + budget = budget.clone() + + # NOTE: All other values of budget are ones that should remain + # constant, there are currently only these two which are dynamic as + # optimization unfold + budget.used_cost_budget = sum( + trial.report.cost + for trial in trials.values() + if trial.report is not None and trial.report.cost is not None + ) + budget.used_evaluations = len(trials) + sampled_config_maybe_new_opt_state = optimizer.ask( trials=trials, budget_info=budget, @@ -166,16 +179,6 @@ def report_trial_evaluation( trial.report = report shared_trial.put(trial) logger.debug("Updated trial '%s' with status '%s'", trial.id, trial.state) - with self._optimizer_state.acquire() as (opt_state, put_opt_state): - # TODO: If an optimizer doesn't use the state, this is a waste of time. - # Update the budget if we have one. - if opt_state.budget is not None: - budget_info = opt_state.budget - - if report.cost is not None: - budget_info.used_cost_budget += report.cost - put_opt_state(opt_state) - if report.err is not None: with self._shared_errors.acquire() as (errs, put_errs): trial_err = ErrDump.SerializableTrialError( From 9df0351079b79c9fe1b21e205bccf5449227d2f5 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Thu, 28 Nov 2024 15:53:19 +0100 Subject: [PATCH 03/56] fix(runtime): Allow multiple retries of creating/loading NePSState. 
--- neps/runtime.py | 47 ++++++++++++++++++++++++++++++++--------------- 1 file changed, 32 insertions(+), 15 deletions(-) diff --git a/neps/runtime.py b/neps/runtime.py index 0a095865..ac667a82 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -45,6 +45,7 @@ def _default_worker_name() -> str: N_FAILED_GET_NEXT_PENDING_ATTEMPTS_BEFORE_ERROR = 0 N_FAILED_TO_SET_TRIAL_STATE = 10 +N_FAILED_CREATE_LOAD_STATE_ATTEMPTS = 10 Loc = TypeVar("Loc") @@ -514,21 +515,37 @@ def _launch_runtime( # noqa: PLR0913 ) shutil.rmtree(optimization_dir) - neps_state = create_or_load_filebased_neps_state( - directory=optimization_dir, - optimizer_info=OptimizerInfo(optimizer_info), - optimizer_state=OptimizationState( - budget=( - BudgetInfo( - max_cost_budget=max_cost_total, - used_cost_budget=0, - max_evaluations=max_evaluations_total, - used_evaluations=0, - ) - ), - shared_state={}, # TODO: Unused for the time being... - ), - ) + retry_count = 0 + for retry_count in range(N_FAILED_CREATE_LOAD_STATE_ATTEMPTS): + try: + neps_state = create_or_load_filebased_neps_state( + directory=optimization_dir, + optimizer_info=OptimizerInfo(optimizer_info), + optimizer_state=OptimizationState( + budget=( + BudgetInfo( + max_cost_budget=max_cost_total, + used_cost_budget=0, + max_evaluations=max_evaluations_total, + used_evaluations=0, + ) + ), + shared_state={}, # TODO: Unused for the time being... + ), + ) + break + except Exception as e: + time.sleep(0.5) + logger.debug( + "Error while trying to create or load the NePS state. Retrying...", + exc_info=True, + ) + else: + raise RuntimeError( + "Failed to create or load the NePS state after" + f" {retry_count} attempts. Bailing!" + " Please enable debug logging to see the errors that occurred."
+ ) settings = WorkerSettings( on_error=( From 115dc9b73eb5ef9524a8d21d81cd576478a780eb Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Thu, 28 Nov 2024 16:00:59 +0100 Subject: [PATCH 04/56] style: Fix pre-commit --- neps/runtime.py | 7 +++---- neps/state/neps_state.py | 3 ++- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/neps/runtime.py b/neps/runtime.py index ac667a82..1cdc1347 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -515,8 +515,7 @@ def _launch_runtime( # noqa: PLR0913 ) shutil.rmtree(optimization_dir) - retry_count = 0 - for retry_count in range(N_FAILED_CREATE_LOAD_STATE_ATTEMPTS): + for _retry_count in range(N_FAILED_CREATE_LOAD_STATE_ATTEMPTS): try: neps_state = create_or_load_filebased_neps_state( directory=optimization_dir, optimizer_info=OptimizerInfo(optimizer_info), optimizer_state=OptimizationState( budget=( BudgetInfo( max_cost_budget=max_cost_total, used_cost_budget=0, max_evaluations=max_evaluations_total, used_evaluations=0, ) ), shared_state={}, # TODO: Unused for the time being... ), ) break - except Exception as e: + except Exception: # noqa: BLE001 time.sleep(0.5) logger.debug( "Error while trying to create or load the NePS state. Retrying...", exc_info=True, ) else: raise RuntimeError( "Failed to create or load the NePS state after" - f" {retry_count} attempts. Bailing!" + f" {N_FAILED_CREATE_LOAD_STATE_ATTEMPTS} attempts. Bailing!" " Please enable debug logging to see the errors that occurred."
) diff --git a/neps/state/neps_state.py b/neps/state/neps_state.py index a2249ed9..0390c8c0 100644 --- a/neps/state/neps_state.py +++ b/neps/state/neps_state.py @@ -102,7 +102,8 @@ def sample_trial( for hook in _sample_hooks: optimizer = hook(optimizer) - # NOTE: Re-work this, as the part's that are recomputed do not need to be serialized + # NOTE: Re-work this, as the part's that are recomputed + # do not need to be serialized budget = opt_state.budget if budget is not None: budget = budget.clone() From 8b274e61d0e75ee316179ffed2fc94e93933f8ac Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Thu, 28 Nov 2024 17:41:45 +0100 Subject: [PATCH 05/56] ux(state): Improve error contents for VersionMismatchError --- neps/state/filebased.py | 10 ++++++++++ neps/state/protocols.py | 20 ++++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/neps/state/filebased.py b/neps/state/filebased.py index cf53c622..ab8e4bda 100644 --- a/neps/state/filebased.py +++ b/neps/state/filebased.py @@ -214,6 +214,8 @@ def pending(self) -> Iterable[tuple[str, Synced[Trial, Path]]]: class ReaderWriterTrial(ReaderWriter[Trial, Path]): """ReaderWriter for Trial objects.""" + CHEAP_LOCKLESS_READ: ClassVar = True + CONFIG_FILENAME = "config.yaml" METADATA_FILENAME = "metadata.yaml" STATE_FILENAME = "state.txt" REPORT_FILENAME = "report.yaml" PREVIOUS_TRIAL_ID_FILENAME = "previous_trial_id.txt" @@ -261,6 +263,8 @@ def write(cls, trial: Trial, directory: Path) -> None: class ReaderWriterSeedSnapshot(ReaderWriter[SeedSnapshot, Path]): """ReaderWriter for SeedSnapshot objects.""" + CHEAP_LOCKLESS_READ: ClassVar = True + # It seems like they're all uint32 but I can't be sure.
PY_RNG_STATE_DTYPE: ClassVar = np.int64 @@ -358,6 +362,8 @@ def write(cls, snapshot: SeedSnapshot, directory: Path) -> None: class ReaderWriterOptimizerInfo(ReaderWriter[OptimizerInfo, Path]): """ReaderWriter for OptimizerInfo objects.""" + CHEAP_LOCKLESS_READ: ClassVar = True + INFO_FILENAME: ClassVar = "info.yaml" @override @@ -381,6 +387,8 @@ def write(cls, optimizer_info: OptimizerInfo, directory: Path) -> None: class ReaderWriterOptimizationState(ReaderWriter[OptimizationState, Path]): """ReaderWriter for OptimizationState objects.""" + CHEAP_LOCKLESS_READ: ClassVar = True + STATE_FILE_NAME: ClassVar = "state.yaml" @override @@ -406,6 +414,8 @@ def write(cls, info: OptimizationState, directory: Path) -> None: class ReaderWriterErrDump(ReaderWriter[ErrDump, Path]): """ReaderWriter for shared error lists.""" + CHEAP_LOCKLESS_READ: ClassVar = True + name: str @override diff --git a/neps/state/protocols.py b/neps/state/protocols.py index 51bff7d3..ec0e28c0 100644 --- a/neps/state/protocols.py +++ b/neps/state/protocols.py @@ -94,6 +94,14 @@ class ReaderWriter(Protocol[T, Loc_contra]): trials, given some `Path`. """ + CHEAP_LOCKLESS_READ: ClassVar[bool] + """Whether reading the contents of the resource is cheap, cheap enough to be + most likely safe without a lock if outdated information is acceptable. + + This is currently used to help debugging instances of a VersionMismatchError + to see what the current state is and what was attempted to be written. + """ + def read(self, loc: Loc_contra, /) -> T: """Read the resource at the given location.""" ... @@ -278,6 +286,17 @@ def put(self, data: T) -> None: """ current_version = self._versioner.current() if self._version != current_version: + # We will attempt to do a lockless read on the contents of the items, as this + # would allow us to better debug in the error raised below. 
+ if self._reader_writer.CHEAP_LOCKLESS_READ: + current_contents = self._reader_writer.read(self._location) + extra_msg = ( + f"\nThe attempted write was: {data}\n" + f"The current contents are: {current_contents}" + ) + else: + extra_msg = "" + raise self.VersionMismatchError( f"Version mismatch - ours: '{self._version}', remote: '{current_version}'" f" Tried to put data at '{self._location}'. Doing so would overwrite" @@ -285,6 +304,7 @@ def put(self, data: T) -> None: " version of the resource and try again." " The most possible reasons for this error is that a lock was not" " utilized when getting this resource before putting it back." + f"{extra_msg}" ) self._reader_writer.write(data, self._location) From b3f76bd226464b36c60625a76799a00e61c67a38 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Thu, 28 Nov 2024 17:44:30 +0100 Subject: [PATCH 06/56] doc: Add comment on why we do not retry reporting --- neps/runtime.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/neps/runtime.py b/neps/runtime.py index 1cdc1347..ea3b12bf 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -476,6 +476,9 @@ def run(self) -> None: # noqa: C901, PLR0915 logger.exception(report.err) _error_from_evaluation = report.err + # We do not retry this, as if some other worker has + # managed to manipulate this trial in the meantime, + # then something has gone wrong self.state.report_trial_evaluation( trial=evaluated_trial, report=report, From 8c59d04d6f9c3da740ef7ae3d8b5517e0338b695 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Thu, 28 Nov 2024 17:47:47 +0100 Subject: [PATCH 07/56] fix: Template string in error message --- neps/runtime.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/neps/runtime.py b/neps/runtime.py index ea3b12bf..2b27174e 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -440,8 +440,8 @@ def run(self) -> None: # noqa: C901, PLR0915 if n_failed_set_trial_state != 0: if n_failed_set_trial_state >= N_FAILED_TO_SET_TRIAL_STATE: raise 
WorkerFailedToGetPendingTrialsError( - "Worker '%s' failed to set trial to evaluating %d times in a row." - " Bailing!" + f"Worker {self.worker_id} failed to set trial to evaluating" + f" {N_FAILED_TO_SET_TRIAL_STATE} times in a row. Bailing!" ) continue From f4ef73ba4ef5ef2bb4f959baf80f779db0bc4e34 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Thu, 28 Nov 2024 19:12:20 +0100 Subject: [PATCH 08/56] fix: Don't reload configurations after checking `pending()` state --- neps/state/filebased.py | 4 ++-- neps/state/neps_state.py | 4 +--- neps/state/protocols.py | 2 +- neps/state/trial.py | 4 ---- 4 files changed, 4 insertions(+), 10 deletions(-) diff --git a/neps/state/filebased.py b/neps/state/filebased.py index ab8e4bda..42fd5090 100644 --- a/neps/state/filebased.py +++ b/neps/state/filebased.py @@ -201,9 +201,9 @@ def all(self) -> dict[str, Synced[Trial, Path]]: return {trial_id: self.get_by_id(trial_id) for trial_id in self.all_trial_ids()} @override - def pending(self) -> Iterable[tuple[str, Synced[Trial, Path]]]: + def pending(self) -> Iterable[tuple[str, Trial]]: pending = [ - (_id, t, trial.metadata.time_sampled) + (_id, trial, trial.metadata.time_sampled) for (_id, t) in self.all().items() if (trial := t.synced()).state == Trial.State.PENDING ] diff --git a/neps/state/neps_state.py b/neps/state/neps_state.py index 0390c8c0..2bbaf0b9 100644 --- a/neps/state/neps_state.py +++ b/neps/state/neps_state.py @@ -211,9 +211,7 @@ def get_next_pending_trial(self, n: int | None = None) -> Trial | list[Trial] | Returns: The next trial or a list of trials if `n` is not `None`. 
""" - _pending_itr = ( - shared_trial.synced() for _, shared_trial in self._trials.pending() - ) + _pending_itr = (shared_trial for _, shared_trial in self._trials.pending()) if n is not None: return take(n, _pending_itr) return next(_pending_itr, None) diff --git a/neps/state/protocols.py b/neps/state/protocols.py index ec0e28c0..addb0299 100644 --- a/neps/state/protocols.py +++ b/neps/state/protocols.py @@ -137,7 +137,7 @@ def all(self) -> dict[str, Synced[Trial, K]]: """Get all trials in the repo.""" ... - def pending(self) -> Iterable[tuple[str, Synced[Trial, K]]]: + def pending(self) -> Iterable[tuple[str, Trial]]: """Get all pending trials in the repo. !!! note diff --git a/neps/state/trial.py b/neps/state/trial.py index 75ad0664..f89dd67e 100644 --- a/neps/state/trial.py +++ b/neps/state/trial.py @@ -36,10 +36,6 @@ class State(Enum): CORRUPTED = "corrupted" UNKNOWN = "unknown" - def pending(self) -> bool: - """Return True if the trial is pending.""" - return self in (State.PENDING, State.SUBMITTED, State.EVALUATING) - @dataclass class MetaData: From 2b1e69b78eb0cbbc79eb29ee16f3ca67f8cb982b Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Thu, 28 Nov 2024 19:13:41 +0100 Subject: [PATCH 09/56] doc: Type fix --- neps/state/protocols.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/neps/state/protocols.py b/neps/state/protocols.py index addb0299..b08b6ffa 100644 --- a/neps/state/protocols.py +++ b/neps/state/protocols.py @@ -202,7 +202,7 @@ def new( current_version = versioner.current() if current_version is not None: raise VersionedResourceAlreadyExistsError( - f"A versioend resource already already exists at '{location}'" + f"A versioned resource already already exists at '{location}'" f" with version '{current_version}'" ) From 8186289c33b515ba4fa700d878b1ca900b47e67f Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Thu, 28 Nov 2024 19:52:16 +0100 Subject: [PATCH 10/56] optim: Reduce IO calls --- neps/state/filebased.py | 30 
++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/neps/state/filebased.py b/neps/state/filebased.py index 42fd5090..4a96bab3 100644 --- a/neps/state/filebased.py +++ b/neps/state/filebased.py @@ -169,8 +169,29 @@ def put_new( """ config_path = self.directory.absolute().resolve() / f"config_{trial.metadata.id}" if config_path.exists(): + # This shouldn't exist, we load in the trial to see the current state of it + # to try determine wtf is going on for logging purposes. + try: + shared_trial = Synced.load( + location=config_path, + locker=FileLocker( + lock_path=config_path / ".lock", + poll=lock_poll, + timeout=lock_timeout, + ), + versioner=FileVersioner(version_file=config_path / ".version"), + reader_writer=ReaderWriterTrial(), + ) + already_existing_trial = shared_trial._unsynced() + extra_msg = ( + f"The existing trial is the following: {already_existing_trial}" + ) + except Exception: # noqa: BLE001 + extra_msg = "Failed to load the existing trial to provide more info." + raise TrialRepo.TrialAlreadyExistsError( f"Trial '{trial.metadata.id}' already exists as '{config_path}'." 
+ f"\n{extra_msg}" ) # HACK: We do this here as there is no way to know where a Trial will @@ -290,18 +311,19 @@ def read(cls, directory: Path) -> SeedSnapshot: np_rng_state = np.fromfile(np_rng_path, dtype=np.uint32) seed_info = deserialize(seedinfo_path) - torch_exists = torch_rng_path.exists() or torch_cuda_rng_path.exists() + torch_rng_path_exists = torch_rng_path.exists() + torch_cuda_rng_path_exists = torch_cuda_rng_path.exists() # By specifying `weights_only=True`, it disables arbitrary object loading torch_rng_state = None torch_cuda_rng = None - if torch_exists: + if torch_rng_path_exists or torch_cuda_rng_path_exists: import torch - if torch_rng_path.exists(): + if torch_rng_path_exists: torch_rng_state = torch.load(torch_rng_path, weights_only=True) - if torch_cuda_rng_path.exists(): + if torch_cuda_rng_path_exists: # By specifying `weights_only=True`, it disables arbitrary object loading torch_cuda_rng = torch.load(torch_cuda_rng_path, weights_only=True) From b2c9eac72c7b78ded2200a0095fa971373756154 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Thu, 28 Nov 2024 19:52:29 +0100 Subject: [PATCH 11/56] ux: Improve logging --- neps/runtime.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/neps/runtime.py b/neps/runtime.py index 2b27174e..ea4ac7bc 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -397,7 +397,9 @@ def run(self) -> None: # noqa: C901, PLR0915 except Exception as e: _repeated_fail_get_next_trial_count += 1 logger.debug( - "Error while trying to get the next trial to evaluate.", exc_info=True + "Worker '%s': Error while trying to get the next trial to evaluate.", + self.worker_id, + exc_info=True, ) # NOTE: This is to prevent any infinite loops if we can't get a trial @@ -406,8 +408,9 @@ def run(self) -> None: # noqa: C901, PLR0915 >= N_FAILED_GET_NEXT_PENDING_ATTEMPTS_BEFORE_ERROR ): raise WorkerFailedToGetPendingTrialsError( - "Worker '%s' failed to get pending trials %d times in a row." 
- " Bailing!" + f"Worker {self.worker_id} failed to get pending trials" + f" {N_FAILED_GET_NEXT_PENDING_ATTEMPTS_BEFORE_ERROR} times in" + " a row. Bailing!" ) from e continue @@ -423,15 +426,21 @@ def run(self) -> None: # noqa: C901, PLR0915 except VersionMismatchError: n_failed_set_trial_state += 1 logger.debug( - f"Another worker has managed to change trial '{trial_to_eval.id}'" - " to evaluate and put back into state. This is fine and likely means" - " the other worker is evaluating it.", + "Another worker has managed to change trial '%s'" + " while this worker '%s' was trying to set it to" + " evaluating. This is fine and likely means the other worker is" + " evaluating it, this worker will attempt to sample new trial.", + trial_to_eval.id, + self.worker_id, exc_info=True, ) except Exception: n_failed_set_trial_state += 1 logger.error( - f"Error trying to set trial '{trial_to_eval.id}' to evaluating.", + "Unexpected error from worker '%s' trying to set trial" + " '%s' to evaluating.", + self.worker_id, + trial_to_eval.id, exc_info=True, ) From cf8469f79cddcff574731c92ce6efd5c28aaf2ca Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Fri, 29 Nov 2024 11:21:18 +0100 Subject: [PATCH 12/56] fix: Bump retries on getting next task to 10 --- neps/runtime.py | 2 +- neps/state/neps_state.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/neps/runtime.py b/neps/runtime.py index ea4ac7bc..f3ae8453 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -43,7 +43,7 @@ def _default_worker_name() -> str: return f"{os.getpid()}-{isoformat}" -N_FAILED_GET_NEXT_PENDING_ATTEMPTS_BEFORE_ERROR = 0 +N_FAILED_GET_NEXT_PENDING_ATTEMPTS_BEFORE_ERROR = 10 N_FAILED_TO_SET_TRIAL_STATE = 10 N_FAILED_CREATE_LOAD_STATE_ATTEMPTS = 10 diff --git a/neps/state/neps_state.py b/neps/state/neps_state.py index 2bbaf0b9..145c998e 100644 --- a/neps/state/neps_state.py +++ b/neps/state/neps_state.py @@ -89,6 +89,12 @@ def sample_trial( ), self._seed_state.acquire() as
(seed_state, put_seed_state), ): + # NOTE: We make the assumption that as we have acquired the optimizer + # state, there is not possibility of another trial being created between + # the time we read in the trials below and `ask()`ing for the next trials + # from the optimizer. If so, that means there is another source of trial + # generation that occurs outside of this function and outside the scope + # of acquiring the optimizer_state lock. trials: dict[str, Trial] = {} for trial_id, shared_trial in self._trials.all().items(): trial = shared_trial.synced() From 801f76413e053cab956a6d9b21448c6dad1944af Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Fri, 29 Nov 2024 11:22:22 +0100 Subject: [PATCH 13/56] fix: Add time delays before retries --- neps/runtime.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/neps/runtime.py b/neps/runtime.py index f3ae8453..42c8b607 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -401,6 +401,7 @@ def run(self) -> None: # noqa: C901, PLR0915 self.worker_id, exc_info=True, ) + time.sleep(1) # Help stagger retries # NOTE: This is to prevent any infinite loops if we can't get a trial if ( @@ -434,6 +435,7 @@ def run(self) -> None: # noqa: C901, PLR0915 self.worker_id, exc_info=True, ) + time.sleep(1) # Help stagger retries except Exception: n_failed_set_trial_state += 1 logger.error( @@ -443,6 +445,7 @@ def run(self) -> None: # noqa: C901, PLR0915 trial_to_eval.id, exc_info=True, ) + time.sleep(1) # Help stagger retries # NOTE: This is to prevent infinite looping if it somehow keeps getting # the same trial and can't set it to evaluating. 
From 2b5fc1187bc555a61fa5eece6371f553efcd1cdc Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Fri, 29 Nov 2024 11:27:29 +0100 Subject: [PATCH 14/56] refactor: Fixup name of error --- neps/exceptions.py | 4 ++-- neps/state/filebased.py | 1 + neps/state/protocols.py | 4 ++-- neps/utils/cli.py | 4 ++-- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/neps/exceptions.py b/neps/exceptions.py index 7054d7c6..bcfe198f 100644 --- a/neps/exceptions.py +++ b/neps/exceptions.py @@ -27,7 +27,7 @@ class VersionedResourceRemovedError(NePSError): """ -class VersionedResourceDoesNotExistsError(NePSError): +class VersionedResourceDoesNotExistError(NePSError): """Raised when a versioned resource does not exist at a location.""" @@ -39,7 +39,7 @@ class TrialAlreadyExistsError(VersionedResourceAlreadyExistsError): """Raised when a trial already exists in the store.""" -class TrialNotFoundError(VersionedResourceDoesNotExistsError): +class TrialNotFoundError(VersionedResourceDoesNotExistError): """Raised when a trial already exists in the store.""" diff --git a/neps/state/filebased.py b/neps/state/filebased.py index 4a96bab3..c4c838e0 100644 --- a/neps/state/filebased.py +++ b/neps/state/filebased.py @@ -191,6 +191,7 @@ def put_new( raise TrialRepo.TrialAlreadyExistsError( f"Trial '{trial.metadata.id}' already exists as '{config_path}'." + f" Tried to put in the trial: {trial}." 
f"\n{extra_msg}" ) diff --git a/neps/state/protocols.py b/neps/state/protocols.py index b08b6ffa..3185f82c 100644 --- a/neps/state/protocols.py +++ b/neps/state/protocols.py @@ -18,7 +18,7 @@ TrialAlreadyExistsError, TrialNotFoundError, VersionedResourceAlreadyExistsError, - VersionedResourceDoesNotExistsError, + VersionedResourceDoesNotExistError, VersionedResourceRemovedError, VersionMismatchError, ) @@ -160,7 +160,7 @@ class VersionedResource(Generic[T, K]): """ VersionMismatchError: ClassVar = VersionMismatchError - VersionedResourceDoesNotExistsError: ClassVar = VersionedResourceDoesNotExistsError + VersionedResourceDoesNotExistsError: ClassVar = VersionedResourceDoesNotExistError VersionedResourceAlreadyExistsError: ClassVar = VersionedResourceAlreadyExistsError VersionedResourceRemovedError: ClassVar = VersionedResourceRemovedError diff --git a/neps/utils/cli.py b/neps/utils/cli.py index 1ad92d4a..cde70357 100644 --- a/neps/utils/cli.py +++ b/neps/utils/cli.py @@ -46,7 +46,7 @@ ) from neps.state.neps_state import NePSState from neps.state.trial import Trial -from neps.exceptions import VersionedResourceDoesNotExistsError, TrialNotFoundError +from neps.exceptions import VersionedResourceDoesNotExistError, TrialNotFoundError from neps.status.status import get_summary_dict from neps.api import _run_args from neps.state.optimizer import BudgetInfo, OptimizationState, OptimizerInfo @@ -666,7 +666,7 @@ def load_neps_state(directory_path: Path) -> Optional[NePSState[Path]]: """Load the NePS state with error handling.""" try: return load_filebased_neps_state(directory_path) - except VersionedResourceDoesNotExistsError: + except VersionedResourceDoesNotExistError: print(f"Error: No NePS state found in the directory '{directory_path}'.") print("Ensure that the NePS run has been initialized correctly.") except Exception as e: From 906fd4f310d733ffe45185fd6a9980419164c823 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Fri, 29 Nov 2024 11:30:09 +0100 Subject: [PATCH 
15/56] fix: Add retry to checking if worker should stop --- neps/runtime.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/neps/runtime.py b/neps/runtime.py index 42c8b607..ae226675 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -383,13 +383,24 @@ def run(self) -> None: # noqa: C901, PLR0915 n_failed_set_trial_state = 0 while True: # NOTE: We rely on this function to do logging and raising errors if it should - should_stop = self._check_if_should_stop( - time_monotonic_start=_time_monotonic_start, - error_from_this_worker=_error_from_evaluation, - ) - if should_stop is not False: - logger.info(should_stop) - break + try: + should_stop = self._check_if_should_stop( + time_monotonic_start=_time_monotonic_start, + error_from_this_worker=_error_from_evaluation, + ) + if should_stop is not False: + logger.info(should_stop) + break + except WorkerRaiseError as e: + raise e + except Exception: + logger.error( + "Unexpected error from worker '%s' while checking if it should stop.", + self.worker_id, + exc_info=True, + ) + time.sleep(1) # Help stagger retries + continue try: trial_to_eval = self._get_next_trial_from_state() From 9c88668370e05121283e74b6a8a2ca742977c583 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Fri, 29 Nov 2024 11:38:43 +0100 Subject: [PATCH 16/56] optim: Reduce object creation overhead --- neps/state/filebased.py | 109 +++++++++++++++++++++------------------- 1 file changed, 56 insertions(+), 53 deletions(-) diff --git a/neps/state/filebased.py b/neps/state/filebased.py index c4c838e0..aa90ab61 100644 --- a/neps/state/filebased.py +++ b/neps/state/filebased.py @@ -25,7 +25,7 @@ from contextlib import contextmanager from dataclasses import asdict, dataclass, field from pathlib import Path -from typing import ClassVar, TypeVar +from typing import ClassVar, Final, TypeVar from typing_extensions import override from uuid import uuid4 @@ -83,6 +83,58 @@ def bump(self) -> str: return sha +@dataclass 
+class ReaderWriterTrial(ReaderWriter[Trial, Path]): + """ReaderWriter for Trial objects.""" + + CHEAP_LOCKLESS_READ: ClassVar = True + + CONFIG_FILENAME = "config.yaml" + METADATA_FILENAME = "metadata.yaml" + STATE_FILENAME = "state.txt" + REPORT_FILENAME = "report.yaml" + PREVIOUS_TRIAL_ID_FILENAME = "previous_trial_id.txt" + + @override + @classmethod + def read(cls, directory: Path) -> Trial: + config_path = directory / cls.CONFIG_FILENAME + metadata_path = directory / cls.METADATA_FILENAME + state_path = directory / cls.STATE_FILENAME + report_path = directory / cls.REPORT_FILENAME + + return Trial( + config=deserialize(config_path), + metadata=Trial.MetaData(**deserialize(metadata_path)), + state=Trial.State(state_path.read_text(encoding="utf-8").strip()), + report=( + Trial.Report(**deserialize(report_path)) if report_path.exists() else None + ), + ) + + @override + @classmethod + def write(cls, trial: Trial, directory: Path) -> None: + config_path = directory / cls.CONFIG_FILENAME + metadata_path = directory / cls.METADATA_FILENAME + state_path = directory / cls.STATE_FILENAME + + serialize(trial.config, config_path) + serialize(asdict(trial.metadata), metadata_path) + state_path.write_text(trial.state.value, encoding="utf-8") + + if trial.metadata.previous_trial_id is not None: + previous_trial_path = directory / cls.PREVIOUS_TRIAL_ID_FILENAME + previous_trial_path.write_text(trial.metadata.previous_trial_id) + + if trial.report is not None: + report_path = directory / cls.REPORT_FILENAME + serialize(asdict(trial.report), report_path) + + +_StaticReaderWriterTrial: Final = ReaderWriterTrial() + + @dataclass class TrialRepoInDirectory(TrialRepo[Path]): """A repository of Trials that are stored in a directory.""" @@ -140,7 +192,7 @@ def get_by_id( timeout=lock_timeout, ), versioner=FileVersioner(version_file=config_path / ".version"), - reader_writer=ReaderWriterTrial(), + reader_writer=_StaticReaderWriterTrial, ) self._cache[trial_id] = trial return trial 
@@ -180,7 +232,7 @@ def put_new( timeout=lock_timeout, ), versioner=FileVersioner(version_file=config_path / ".version"), - reader_writer=ReaderWriterTrial(), + reader_writer=_StaticReaderWriterTrial, ) already_existing_trial = shared_trial._unsynced() extra_msg = ( @@ -207,7 +259,7 @@ def put_new( timeout=lock_timeout, ), versioner=FileVersioner(version_file=config_path / ".version"), - reader_writer=ReaderWriterTrial(), + reader_writer=_StaticReaderWriterTrial, ) self._cache[trial.metadata.id] = shared_trial return shared_trial @@ -232,55 +284,6 @@ def pending(self) -> Iterable[tuple[str, Trial]]: return iter((_id, t) for _id, t, _ in sorted(pending, key=lambda x: x[2])) -@dataclass -class ReaderWriterTrial(ReaderWriter[Trial, Path]): - """ReaderWriter for Trial objects.""" - - CHEAP_LOCKLESS_READ: ClassVar = True - - CONFIG_FILENAME = "config.yaml" - METADATA_FILENAME = "metadata.yaml" - STATE_FILENAME = "state.txt" - REPORT_FILENAME = "report.yaml" - PREVIOUS_TRIAL_ID_FILENAME = "previous_trial_id.txt" - - @override - @classmethod - def read(cls, directory: Path) -> Trial: - config_path = directory / cls.CONFIG_FILENAME - metadata_path = directory / cls.METADATA_FILENAME - state_path = directory / cls.STATE_FILENAME - report_path = directory / cls.REPORT_FILENAME - - return Trial( - config=deserialize(config_path), - metadata=Trial.MetaData(**deserialize(metadata_path)), - state=Trial.State(state_path.read_text(encoding="utf-8").strip()), - report=( - Trial.Report(**deserialize(report_path)) if report_path.exists() else None - ), - ) - - @override - @classmethod - def write(cls, trial: Trial, directory: Path) -> None: - config_path = directory / cls.CONFIG_FILENAME - metadata_path = directory / cls.METADATA_FILENAME - state_path = directory / cls.STATE_FILENAME - - serialize(trial.config, config_path) - serialize(asdict(trial.metadata), metadata_path) - state_path.write_text(trial.state.value, encoding="utf-8") - - if trial.metadata.previous_trial_id is not 
None: - previous_trial_path = directory / cls.PREVIOUS_TRIAL_ID_FILENAME - previous_trial_path.write_text(trial.metadata.previous_trial_id) - - if trial.report is not None: - report_path = directory / cls.REPORT_FILENAME - serialize(asdict(trial.report), report_path) - - @dataclass class ReaderWriterSeedSnapshot(ReaderWriter[SeedSnapshot, Path]): """ReaderWriter for SeedSnapshot objects.""" From f0f3f641d5be40b62333c2bb050afadf7038fadb Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Fri, 29 Nov 2024 11:51:04 +0100 Subject: [PATCH 17/56] optim: Remove use of `set`, cheaper string op --- neps/state/filebased.py | 10 ++++++---- neps/state/neps_state.py | 19 ++++++++----------- neps/state/protocols.py | 2 +- 3 files changed, 15 insertions(+), 16 deletions(-) diff --git a/neps/state/filebased.py b/neps/state/filebased.py index aa90ab61..7cbf8efe 100644 --- a/neps/state/filebased.py +++ b/neps/state/filebased.py @@ -134,6 +134,8 @@ def write(cls, trial: Trial, directory: Path) -> None: _StaticReaderWriterTrial: Final = ReaderWriterTrial() +CONFIG_PREFIX_LEN: Final = len("config_") + @dataclass class TrialRepoInDirectory(TrialRepo[Path]): @@ -143,13 +145,13 @@ class TrialRepoInDirectory(TrialRepo[Path]): _cache: dict[str, Synced[Trial, Path]] = field(default_factory=dict) @override - def all_trial_ids(self) -> set[str]: + def all_trial_ids(self) -> list[str]: """List all the trial ids in this trial Repo.""" - return { - config_path.name.replace("config_", "") + return [ + config_path.name[CONFIG_PREFIX_LEN:] for config_path in self.directory.iterdir() if config_path.name.startswith("config_") and config_path.is_dir() - } + ] @override def get_by_id( diff --git a/neps/state/neps_state.py b/neps/state/neps_state.py index 145c998e..95d9dbd2 100644 --- a/neps/state/neps_state.py +++ b/neps/state/neps_state.py @@ -83,10 +83,7 @@ def sample_trial( The new trial. 
""" with ( - self._optimizer_state.acquire() as ( - opt_state, - put_opt, - ), + self._optimizer_state.acquire() as (opt_state, put_opt), self._seed_state.acquire() as (seed_state, put_seed_state), ): # NOTE: We make the assumption that as we have acquired the optimizer @@ -95,10 +92,10 @@ def sample_trial( # from the optimizer. If so, that means there is another source of trial # generation that occurs outside of this function and outside the scope # of acquiring the optimizer_state lock. - trials: dict[str, Trial] = {} - for trial_id, shared_trial in self._trials.all().items(): - trial = shared_trial.synced() - trials[trial_id] = trial + trials: dict[str, Trial] = { + trial_id: shared_trial.synced() + for trial_id, shared_trial in list(self._trials.all().items()) + } seed_state.set_as_global_seed_state() @@ -147,14 +144,14 @@ def sample_trial( trial = Trial.new( trial_id=sampled_config.id, - location="", # HACK: This will be set by the `TrialRepo` + location="", # HACK: This will be set by the `TrialRepo` in `put_new` config=sampled_config.config, previous_trial=sampled_config.previous_config_id, previous_trial_location=previous_trial_location, time_sampled=time.time(), worker_id=worker_id, ) - shared_trial = self._trials.put_new(trial) + self._trials.put_new(trial) seed_state.recapture() put_seed_state(seed_state) put_opt( @@ -222,7 +219,7 @@ def get_next_pending_trial(self, n: int | None = None) -> Trial | list[Trial] | return take(n, _pending_itr) return next(_pending_itr, None) - def all_trial_ids(self) -> set[str]: + def all_trial_ids(self) -> list[str]: """Get all the trial ids that are known about.""" return self._trials.all_trial_ids() diff --git a/neps/state/protocols.py b/neps/state/protocols.py index 3185f82c..7bbe7a9a 100644 --- a/neps/state/protocols.py +++ b/neps/state/protocols.py @@ -121,7 +121,7 @@ class TrialRepo(Protocol[K]): TrialAlreadyExistsError: ClassVar = TrialAlreadyExistsError TrialNotFoundError: ClassVar = TrialNotFoundError - def 
all_trial_ids(self) -> set[str]: + def all_trial_ids(self) -> list[str]: """List all the trial ids in this trial Repo.""" ... From 8bf505139bd94dff17d2ea79c832b0cecc7717f0 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Fri, 29 Nov 2024 12:00:50 +0100 Subject: [PATCH 18/56] feat: Allow setting of retries from ENV vars --- neps/env.py | 15 +++++++++++++++ neps/runtime.py | 24 +++++++++++------------- 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/neps/env.py b/neps/env.py index f7ea4167..ce011885 100644 --- a/neps/env.py +++ b/neps/env.py @@ -28,6 +28,21 @@ def is_nullable(e: str) -> bool: return e.lower() in ("none", "n", "null") +MAX_RETRIES_GET_NEXT_TRIAL = get_env( + "NEPS_MAX_RETRIES_GET_NEXT_TRIAL", + parse=int, + default=10, +) +MAX_RETRIES_SET_EVALUATING = get_env( + "NEPS_MAX_RETRIES_SET_EVALUATING", + parse=int, + default=10, +) +MAX_RETRIES_CREATE_LOAD_STATE = get_env( + "NEPS_MAX_RETRIES_CREATE_LOAD_STATE", + parse=int, + default=10, +) TRIAL_FILELOCK_POLL = get_env( "NEPS_TRIAL_FILELOCK_POLL", parse=float, diff --git a/neps/runtime.py b/neps/runtime.py index ae226675..c6809296 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -19,6 +19,11 @@ TypeVar, ) +from neps.env import ( + MAX_RETRIES_CREATE_LOAD_STATE, + MAX_RETRIES_GET_NEXT_TRIAL, + MAX_RETRIES_SET_EVALUATING, +) from neps.exceptions import ( NePSError, VersionMismatchError, @@ -43,10 +48,6 @@ def _default_worker_name() -> str: return f"{os.getpid()}-{isoformat}" -N_FAILED_GET_NEXT_PENDING_ATTEMPTS_BEFORE_ERROR = 10 -N_FAILED_TO_SET_TRIAL_STATE = 10 -N_FAILED_CREATE_LOAD_STATE_ATTEMPTS = 10 - Loc = TypeVar("Loc") # NOTE: As each NEPS process is only ever evaluating a single trial, this global can @@ -415,13 +416,10 @@ def run(self) -> None: # noqa: C901, PLR0915 time.sleep(1) # Help stagger retries # NOTE: This is to prevent any infinite loops if we can't get a trial - if ( - _repeated_fail_get_next_trial_count - >= N_FAILED_GET_NEXT_PENDING_ATTEMPTS_BEFORE_ERROR - ): + if 
_repeated_fail_get_next_trial_count >= MAX_RETRIES_GET_NEXT_TRIAL: raise WorkerFailedToGetPendingTrialsError( f"Worker {self.worker_id} failed to get pending trials" - f" {N_FAILED_GET_NEXT_PENDING_ATTEMPTS_BEFORE_ERROR} times in" + f" {MAX_RETRIES_GET_NEXT_TRIAL} times in" " a row. Bailing!" ) from e @@ -461,10 +459,10 @@ def run(self) -> None: # noqa: C901, PLR0915 # NOTE: This is to prevent infinite looping if it somehow keeps getting # the same trial and can't set it to evaluating. if n_failed_set_trial_state != 0: - if n_failed_set_trial_state >= N_FAILED_TO_SET_TRIAL_STATE: + if n_failed_set_trial_state >= MAX_RETRIES_SET_EVALUATING: raise WorkerFailedToGetPendingTrialsError( f"Worker {self.worker_id} failed to set trial to evaluating" - f" {N_FAILED_TO_SET_TRIAL_STATE} times in a row. Bailing!" + f" {MAX_RETRIES_SET_EVALUATING} times in a row. Bailing!" ) continue @@ -541,7 +539,7 @@ def _launch_runtime( # noqa: PLR0913 ) shutil.rmtree(optimization_dir) - for _retry_count in range(N_FAILED_CREATE_LOAD_STATE_ATTEMPTS): + for _retry_count in range(MAX_RETRIES_CREATE_LOAD_STATE): try: neps_state = create_or_load_filebased_neps_state( directory=optimization_dir, @@ -568,7 +566,7 @@ def _launch_runtime( # noqa: PLR0913 else: raise RuntimeError( "Failed to create or load the NePS state after" - f" {N_FAILED_CREATE_LOAD_STATE_ATTEMPTS} attempts. Bailing!" + f" {MAX_RETRIES_CREATE_LOAD_STATE} attempts. Bailing!" " Please enable debug logging to see the errors that occured." ) From 62e130bcd78359e3e0f0e66a261bf161367513cc Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Fri, 29 Nov 2024 12:01:03 +0100 Subject: [PATCH 19/56] doc: Extra logging on issue w.r.t. 
trial already existing --- neps/state/neps_state.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/neps/state/neps_state.py b/neps/state/neps_state.py index 95d9dbd2..3d4f3186 100644 --- a/neps/state/neps_state.py +++ b/neps/state/neps_state.py @@ -18,6 +18,7 @@ from more_itertools import take +from neps.exceptions import TrialAlreadyExistsError from neps.state.err_dump import ErrDump from neps.state.optimizer import OptimizationState, OptimizerInfo from neps.state.trial import Trial @@ -151,7 +152,29 @@ def sample_trial( time_sampled=time.time(), worker_id=worker_id, ) - self._trials.put_new(trial) + try: + self._trials.put_new(trial) + except TrialAlreadyExistsError as e: + if sampled_config.id in trials: + logger.warning( + "The new sampled trial was given an id of '%s', yet this already" + " exists in the loaded in trials given to the optimizer. This" + " indicates a bug with the optimizers allocation of ids.", + sampled_config.id, + ) + else: + logger.warning( + "The new sampled trial was given an id of '%s', which is not one" + " that was loaded in by the optimizer. 
This indicates that" + " configuration '%s' was put on disk during the time that this" + " worker had the optimizer state lock OR that after obtaining the" + " optimizer state lock, somehow this configuration failed to be" + " loaded in and passed to the optimizer.", + sampled_config.id, + sampled_config.id, + ) + raise e + seed_state.recapture() put_seed_state(seed_state) put_opt( From 5b54ce57fb57fe82015f117506efc56632cf4b71 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Fri, 29 Nov 2024 12:09:55 +0100 Subject: [PATCH 20/56] test: Fixup expectation on output of `get_all_trials()` --- tests/test_state/test_filebased_neps_state.py | 2 +- tests/test_state/test_neps_state.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_state/test_filebased_neps_state.py b/tests/test_state/test_filebased_neps_state.py index 8b2d4eb5..02f5a52c 100644 --- a/tests/test_state/test_filebased_neps_state.py +++ b/tests/test_state/test_filebased_neps_state.py @@ -45,7 +45,7 @@ def test_create_with_new_filebased_neps_state( ) assert neps_state.optimizer_info() == optimizer_info assert neps_state.optimizer_state() == optimizer_state - assert neps_state.all_trial_ids() == set() + assert neps_state.all_trial_ids() == [] assert neps_state.get_all_trials() == {} assert neps_state.get_errors() == ErrDump(errs=[]) assert neps_state.get_next_pending_trial() is None diff --git a/tests/test_state/test_neps_state.py b/tests/test_state/test_neps_state.py index 51773fdb..c64cb64e 100644 --- a/tests/test_state/test_neps_state.py +++ b/tests/test_state/test_neps_state.py @@ -175,7 +175,7 @@ def test_sample_trial( assert neps_state.get_all_trials() == {} assert neps_state.get_next_pending_trial() is None assert neps_state.get_next_pending_trial(n=10) == [] - assert neps_state.all_trial_ids() == set() + assert neps_state.all_trial_ids() == [] trial1 = neps_state.sample_trial(optimizer=optimizer, worker_id="1") for k, v in trial1.config.items(): @@ -189,7 +189,7 @@ def 
test_sample_trial( assert neps_state.get_all_trials() == {trial1.id: trial1} assert neps_state.get_next_pending_trial() == trial1 assert neps_state.get_next_pending_trial(n=10) == [trial1] - assert neps_state.all_trial_ids() == {trial1.id} + assert neps_state.all_trial_ids() == [trial1.id] trial2 = neps_state.sample_trial(optimizer=optimizer, worker_id="1") for k, v in trial1.config.items(): @@ -201,4 +201,4 @@ def test_sample_trial( assert neps_state.get_all_trials() == {trial1.id: trial1, trial2.id: trial2} assert neps_state.get_next_pending_trial() == trial1 assert neps_state.get_next_pending_trial(n=10) == [trial1, trial2] - assert neps_state.all_trial_ids() == {trial1.id, trial2.id} + assert sorted(neps_state.all_trial_ids()) == [trial1.id, trial2.id] From ee8e90eb6c05354f642dff93f08b20f59bcca2f8 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Fri, 29 Nov 2024 12:13:41 +0100 Subject: [PATCH 21/56] fix: Add retry counter to checking if should stop --- neps/env.py | 5 +++++ neps/runtime.py | 16 ++++++++++++++-- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/neps/env.py b/neps/env.py index ce011885..58ea0f7f 100644 --- a/neps/env.py +++ b/neps/env.py @@ -43,6 +43,11 @@ def is_nullable(e: str) -> bool: parse=int, default=10, ) +MAX_RETRIES_WORKER_CHECK_SHOULD_STOP = get_env( + "NEPS_MAX_RETRIES_WORKER_CHECK_SHOULD_STOP", + parse=int, + default=3, +) TRIAL_FILELOCK_POLL = get_env( "NEPS_TRIAL_FILELOCK_POLL", parse=float, diff --git a/neps/runtime.py b/neps/runtime.py index c6809296..4599247f 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -23,6 +23,7 @@ MAX_RETRIES_CREATE_LOAD_STATE, MAX_RETRIES_GET_NEXT_TRIAL, MAX_RETRIES_SET_EVALUATING, + MAX_RETRIES_WORKER_CHECK_SHOULD_STOP, ) from neps.exceptions import ( NePSError, @@ -367,7 +368,7 @@ def _check_if_should_stop( # noqa: C901, PLR0912, PLR0911 return False - def run(self) -> None: # noqa: C901, PLR0915 + def run(self) -> None: # noqa: C901, PLR0915, PLR0912 """Run the worker. 
Will keep running until one of the criterion defined by the `WorkerSettings` @@ -382,6 +383,7 @@ def run(self) -> None: # noqa: C901, PLR0915 _repeated_fail_get_next_trial_count = 0 n_failed_set_trial_state = 0 + n_repeated_failed_check_should_stop = 0 while True: # NOTE: We rely on this function to do logging and raising errors if it should try: @@ -394,7 +396,17 @@ def run(self) -> None: # noqa: C901, PLR0915 break except WorkerRaiseError as e: raise e - except Exception: + except Exception as e: + n_repeated_failed_check_should_stop += 1 + if ( + n_repeated_failed_check_should_stop + >= MAX_RETRIES_WORKER_CHECK_SHOULD_STOP + ): + raise WorkerRaiseError( + f"Worker {self.worker_id} failed to check if it should stop" + f" {MAX_RETRIES_WORKER_CHECK_SHOULD_STOP} times in a row. Bailing" + ) from e + logger.error( "Unexpected error from worker '%s' while checking if it should stop.", self.worker_id, From d075b19d467df3a8160ef68bfd6f9fc134752824 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Fri, 29 Nov 2024 12:53:54 +0100 Subject: [PATCH 22/56] doc: Remove spammy lock acquisition debug logs --- neps/state/filebased.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/neps/state/filebased.py b/neps/state/filebased.py index 7cbf8efe..9394d0e1 100644 --- a/neps/state/filebased.py +++ b/neps/state/filebased.py @@ -502,7 +502,6 @@ def lock( ) -> Iterator[None]: self.lock_path.parent.mkdir(parents=True, exist_ok=True) self.lock_path.touch(exist_ok=True) - logger.debug("Acquiring lock on %s", self.lock_path) try: with pl.Lock( self.lock_path, @@ -522,7 +521,6 @@ def lock( " environment variables to increase the timeout:" f"\n\n{pprint.pformat(ENV_VARS_USED)}" ) from e - logger.debug("Released lock on %s", self.lock_path) def load_filebased_neps_state(directory: Path) -> NePSState[Path]: From c0b5b7754a98017b23d33e6081a321db1e22cccd Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Fri, 29 Nov 2024 15:49:06 +0100 Subject: [PATCH 23/56] fix: Better check to not overdo 
sampling --- neps/runtime.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/neps/runtime.py b/neps/runtime.py index 4599247f..a78fdf08 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -35,11 +35,11 @@ from neps.state.filebased import create_or_load_filebased_neps_state from neps.state.optimizer import BudgetInfo, OptimizationState, OptimizerInfo from neps.state.settings import DefaultReportValues, OnErrorPossibilities, WorkerSettings -from neps.state.trial import Trial if TYPE_CHECKING: from neps.optimizers.base_optimizer import BaseOptimizer from neps.state.neps_state import NePSState + from neps.state.trial import Trial logger = logging.getLogger(__name__) @@ -321,13 +321,12 @@ def _check_if_should_stop( # noqa: C901, PLR0912, PLR0911 trials = self.state.get_all_trials() if self.settings.max_evaluations_total is not None: if self.settings.include_in_progress_evaluations_towards_maximum: - count = sum( - 1 - for _, trial in trials.items() - if trial.report is not None - or trial.state in (Trial.State.EVALUATING, Trial.State.SUBMITTED) - ) + # NOTE: We can just use the sum of trials in this case as they + # either have a report, are pending or being evaluated. There + # are also crashed and unknown states which we include into this. + count = len(trials) else: + # This indicates they have completed. 
count = sum(1 for _, trial in trials.items() if trial.report is not None) if count >= self.settings.max_evaluations_total: From c7592c1dd869c7e855d5f8a0b21581c14254a374 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Fri, 29 Nov 2024 16:46:07 +0100 Subject: [PATCH 24/56] fix: Add timeout for locking `post_run_csv` --- neps/status/status.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/neps/status/status.py b/neps/status/status.py index 660c53f3..edc365dc 100644 --- a/neps/status/status.py +++ b/neps/status/status.py @@ -299,7 +299,7 @@ def _save_data_to_csv( config_data_df: The DataFrame containing configuration data. run_data_df: The DataFrame containing additional run data. """ - with locker(poll=2): + with locker(poll=2, timeout=120): try: pending_configs = run_data_df.loc["num_pending_configs", "value"] pending_configs_with_worker = run_data_df.loc[ From 0ff3d41cf512bd77c3ee4b0d126a8640c3de26b1 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Fri, 29 Nov 2024 16:47:39 +0100 Subject: [PATCH 25/56] fix: Add highly generous timeout for post_run_csv --- neps/status/status.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/neps/status/status.py b/neps/status/status.py index edc365dc..bb68e50d 100644 --- a/neps/status/status.py +++ b/neps/status/status.py @@ -299,7 +299,7 @@ def _save_data_to_csv( config_data_df: The DataFrame containing configuration data. run_data_df: The DataFrame containing additional run data. 
""" - with locker(poll=2, timeout=120): + with locker(poll=2, timeout=600): try: pending_configs = run_data_df.loc["num_pending_configs", "value"] pending_configs_with_worker = run_data_df.loc[ From 20b492a77999481188941208af7c6325cb9cdaf0 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 2 Dec 2024 14:29:28 +0100 Subject: [PATCH 26/56] fix: Switch to `lockf` for file-locking on linux --- neps/env.py | 7 +++++++ neps/runtime.py | 9 +++++++++ 2 files changed, 16 insertions(+) diff --git a/neps/env.py b/neps/env.py index 58ea0f7f..cd37d4df 100644 --- a/neps/env.py +++ b/neps/env.py @@ -28,6 +28,13 @@ def is_nullable(e: str) -> bool: return e.lower() in ("none", "n", "null") +LINUX_FILELOCK_FUNCTION = get_env( + "NEPS_LINUX_FILELOCK_FUNCTION", + parse=str, + default="lockf", +) +assert LINUX_FILELOCK_FUNCTION in ("lockf", "flock") + MAX_RETRIES_GET_NEXT_TRIAL = get_env( "NEPS_MAX_RETRIES_GET_NEXT_TRIAL", parse=int, diff --git a/neps/runtime.py b/neps/runtime.py index a78fdf08..fb3e1e9b 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -20,6 +20,7 @@ ) from neps.env import ( + LINUX_FILELOCK_FUNCTION, MAX_RETRIES_CREATE_LOAD_STATE, MAX_RETRIES_GET_NEXT_TRIAL, MAX_RETRIES_SET_EVALUATING, @@ -606,6 +607,14 @@ def _launch_runtime( # noqa: PLR0913 max_cost_for_worker=None, # TODO: User can't specify yet ) + # HACK: Due to nfs file-systems, locking with the default `flock()` is not reliable. + # Hence, we overwrite `portalockers` lock call to use `lockf()` instead. + # This is commeneted in their source code that this is an option to use, however + # it's not directly advertised as a parameter/env variable or otherwise. 
+ import portalocker.portalocker as portalocker_lock_module + + setattr(portalocker_lock_module, "LOCKER", LINUX_FILELOCK_FUNCTION) + worker = DefaultWorker.new( state=neps_state, optimizer=optimizer, From 9d9651d3c1e346ea7cd80bce8e4872661db3a697 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 2 Dec 2024 18:13:54 +0100 Subject: [PATCH 27/56] refactor: Favour larger, longer locks with lockf --- neps/env.py | 37 +- neps/plot/tensorboard_eval.py | 2 +- neps/runtime.py | 148 ++-- neps/state/__init__.py | 12 - neps/state/filebased.py | 454 +---------- neps/state/neps_state.py | 710 +++++++++++++----- neps/state/protocols.py | 577 -------------- neps/status/status.py | 17 +- neps/utils/_locker.py | 62 -- neps/utils/cli.py | 42 +- neps/utils/common.py | 4 +- .../test_default_report_values.py | 25 +- .../test_error_handling_strategies.py | 26 +- tests/test_runtime/test_stopping_criterion.py | 92 +-- tests/test_state/test_filebased_neps_state.py | 45 +- tests/test_state/test_neps_state.py | 32 +- tests/test_state/test_synced.py | 429 ----------- 17 files changed, 792 insertions(+), 1922 deletions(-) delete mode 100644 neps/state/protocols.py delete mode 100644 neps/utils/_locker.py delete mode 100644 tests/test_state/test_synced.py diff --git a/neps/env.py b/neps/env.py index cd37d4df..c614ebac 100644 --- a/neps/env.py +++ b/neps/env.py @@ -33,7 +33,6 @@ def is_nullable(e: str) -> bool: parse=str, default="lockf", ) -assert LINUX_FILELOCK_FUNCTION in ("lockf", "flock") MAX_RETRIES_GET_NEXT_TRIAL = get_env( "NEPS_MAX_RETRIES_GET_NEXT_TRIAL", @@ -66,35 +65,17 @@ def is_nullable(e: str) -> bool: default=120, ) -SEED_SNAPSHOT_FILELOCK_POLL = get_env( - "NEPS_SEED_SNAPSHOT_FILELOCK_POLL", +# NOTE: We want this to be greater than the trials filelock, so that +# anything requesting to just update the trials is more likely to obtain it +# as those operations tend to be faster than something that requires optimizer +# state. 
+STATE_FILELOCK_POLL = get_env( + "NEPS_STATE_FILELOCK_POLL", parse=float, - default=0.05, -) -SEED_SNAPSHOT_FILELOCK_TIMEOUT = get_env( - "NEPS_SEED_SNAPSHOT_FILELOCK_TIMEOUT", - parse=lambda e: None if is_nullable(e) else float(e), - default=120, -) - -OPTIMIZER_INFO_FILELOCK_POLL = get_env( - "NEPS_OPTIMIZER_INFO_FILELOCK_POLL", - parse=float, - default=0.05, -) -OPTIMIZER_INFO_FILELOCK_TIMEOUT = get_env( - "NEPS_OPTIMIZER_INFO_FILELOCK_TIMEOUT", - parse=lambda e: None if is_nullable(e) else float(e), - default=120, -) - -OPTIMIZER_STATE_FILELOCK_POLL = get_env( - "NEPS_OPTIMIZER_STATE_FILELOCK_POLL", - parse=float, - default=0.05, + default=0.20, ) -OPTIMIZER_STATE_FILELOCK_TIMEOUT = get_env( - "NEPS_OPTIMIZER_STATE_FILELOCK_TIMEOUT", +STATE_FILELOCK_TIMEOUT = get_env( + "NEPS_STATE_FILELOCK_TIMEOUT", parse=lambda e: None if is_nullable(e) else float(e), default=120, ) diff --git a/neps/plot/tensorboard_eval.py b/neps/plot/tensorboard_eval.py index 380ad6b4..816f16eb 100644 --- a/neps/plot/tensorboard_eval.py +++ b/neps/plot/tensorboard_eval.py @@ -94,7 +94,7 @@ def _initiate_internal_configurations() -> None: register_notify_trial_end("NEPS_TBLOGGER", tblogger.end_of_config) # We are assuming that neps state is all filebased here - root_dir = Path(neps_state.location) + root_dir = Path(neps_state.path) assert root_dir.exists() tblogger.config_working_directory = Path(trial.metadata.location) diff --git a/neps/runtime.py b/neps/runtime.py index fb3e1e9b..308a2560 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -33,14 +33,13 @@ WorkerRaiseError, ) from neps.state._eval import evaluate_trial -from neps.state.filebased import create_or_load_filebased_neps_state +from neps.state.neps_state import NePSState from neps.state.optimizer import BudgetInfo, OptimizationState, OptimizerInfo from neps.state.settings import DefaultReportValues, OnErrorPossibilities, WorkerSettings +from neps.state.trial import Trial if TYPE_CHECKING: from neps.optimizers.base_optimizer 
import BaseOptimizer - from neps.state.neps_state import NePSState - from neps.state.trial import Trial logger = logging.getLogger(__name__) @@ -64,7 +63,7 @@ def _default_worker_name() -> str: # TODO: This only works with a filebased nepsstate -def get_workers_neps_state() -> NePSState[Path]: +def get_workers_neps_state() -> NePSState: """Get the worker's NePS state.""" if _WORKER_NEPS_STATE is None: raise RuntimeError( @@ -76,7 +75,7 @@ def get_workers_neps_state() -> NePSState[Path]: return _WORKER_NEPS_STATE -def _set_workers_neps_state(state: NePSState[Path]) -> None: +def _set_workers_neps_state(state: NePSState) -> None: global _WORKER_NEPS_STATE # noqa: PLW0603 _WORKER_NEPS_STATE = state @@ -177,27 +176,7 @@ def new( _pre_sample_hooks=_pre_sample_hooks, ) - def _get_next_trial_from_state(self) -> Trial: - nxt_trial = self.state.get_next_pending_trial() - - # If we have a trial, we will use it - if nxt_trial is not None: - logger.info( - f"Worker '{self.worker_id}' got previosly sampled trial: {nxt_trial}" - ) - - # Otherwise sample a new one - else: - nxt_trial = self.state.sample_trial( - worker_id=self.worker_id, - optimizer=self.optimizer, - _sample_hooks=self._pre_sample_hooks, - ) - logger.info(f"Worker '{self.worker_id}' sampled a new trial: {nxt_trial}") - - return nxt_trial - - def _check_if_should_stop( # noqa: C901, PLR0912, PLR0911 + def _check_worker_local_settings( self, *, time_monotonic_start: float, @@ -205,8 +184,6 @@ def _check_if_should_stop( # noqa: C901, PLR0912, PLR0911 ) -> str | Literal[False]: # NOTE: Sorry this code is kind of ugly but it's pretty straightforward, just a # lot of conditional checking and making sure to check cheaper conditions first. - # It would look a little nicer with a match statement but we've got to wait - # for python 3.10 for that. # First check for stopping criterion for this worker in particular as it's # cheaper and doesn't require anything from the state. 
@@ -280,13 +257,16 @@ def _check_if_should_stop( # noqa: C901, PLR0912, PLR0911 f", given by `{self.settings.max_evaluation_time_for_worker_seconds=}`." ) + return False + + def _check_shared_error_stopping_criterion(self) -> str | Literal[False]: # We check this global error stopping criterion as it's much # cheaper than sweeping the state from all trials. if self.settings.on_error in ( OnErrorPossibilities.RAISE_ANY_ERROR, OnErrorPossibilities.STOP_ANY_ERROR, ): - err = self.state._shared_errors.synced().latest_err_as_raisable() + err = self.state.lock_and_get_errors().latest_err_as_raisable() if err is not None: msg = ( "An error occurred in another worker and this worker is set to stop" @@ -306,20 +286,12 @@ def _check_if_should_stop( # noqa: C901, PLR0912, PLR0911 return msg - # If there are no global stopping criterion, we can no just return early. - if ( - self.settings.max_evaluations_total is None - and self.settings.max_cost_total is None - and self.settings.max_evaluation_time_total_seconds is None - ): - return False - - # At this point, if we have some global stopping criterion, we need to sweep - # the current state of trials to determine if we should stop - # NOTE: If these `sum` turn out to somehow be a bottleneck, these could - # be precomputed and accumulated over time. This would have to be handled - # in the `NePSState` class. - trials = self.state.get_all_trials() + return False + + def _check_global_stopping_criterion( + self, + trials: Mapping[str, Trial], + ) -> str | Literal[False]: if self.settings.max_evaluations_total is not None: if self.settings.include_in_progress_evaluations_towards_maximum: # NOTE: We can just use the sum of trials in this case as they @@ -368,6 +340,8 @@ def _check_if_should_stop( # noqa: C901, PLR0912, PLR0911 return False + # Forgive me lord, for I have sinned, this function is atrocious but complicated + # due to locking. def run(self) -> None: # noqa: C901, PLR0915, PLR0912 """Run the worker. 
@@ -385,18 +359,27 @@ def run(self) -> None: # noqa: C901, PLR0915, PLR0912 n_failed_set_trial_state = 0 n_repeated_failed_check_should_stop = 0 while True: - # NOTE: We rely on this function to do logging and raising errors if it should try: - should_stop = self._check_if_should_stop( + # First check local worker settings + should_stop = self._check_worker_local_settings( time_monotonic_start=_time_monotonic_start, error_from_this_worker=_error_from_evaluation, ) if should_stop is not False: logger.info(should_stop) break + + # Next check global errs having occured + should_stop = self._check_shared_error_stopping_criterion() + if should_stop is not False: + logger.info(should_stop) + break + except WorkerRaiseError as e: + # If we raise a specific error, we should stop the worker raise e except Exception as e: + # An unknown exception, check our retry countk n_repeated_failed_check_should_stop += 1 if ( n_repeated_failed_check_should_stop @@ -415,8 +398,48 @@ def run(self) -> None: # noqa: C901, PLR0915, PLR0912 time.sleep(1) # Help stagger retries continue + # From here, we now begin sampling or getting the next pending trial. + # As the global stopping criterion requires us to check all trials, and + # needs to be in locked in-step with sampling try: - trial_to_eval = self._get_next_trial_from_state() + # If there are no global stopping criterion, we can no just return early. 
+ with self.state.lock_for_sampling(): + trials = self.state._trials.latest() + + requires_checking_global_stopping_criterion = ( + self.settings.max_evaluations_total is not None + or self.settings.max_cost_total is not None + or self.settings.max_evaluation_time_total_seconds is not None + ) + if requires_checking_global_stopping_criterion: + should_stop = self._check_global_stopping_criterion(trials) + if should_stop is not False: + logger.info(should_stop) + break + + pending_trials = [ + trial + for trial in trials.values() + if trial.state == Trial.State.PENDING + ] + if len(pending_trials) > 0: + earliest_pending = sorted( + pending_trials, + key=lambda t: t.metadata.time_sampled, + )[0] + earliest_pending.set_evaluating( + time_started=time.time(), + worker_id=self.worker_id, + ) + self.state._trials.update_trial(earliest_pending) + trial_to_eval = earliest_pending + else: + sampled_trial = self.state._sample_trial( + optimizer=self.optimizer, + worker_id=self.worker_id, + ) + trial_to_eval = sampled_trial + _repeated_fail_get_next_trial_count = 0 except Exception as e: _repeated_fail_get_next_trial_count += 1 @@ -439,11 +462,6 @@ def run(self) -> None: # noqa: C901, PLR0915, PLR0912 # If we can't set this working to evaluating, then just retry the loop try: - trial_to_eval.set_evaluating( - time_started=time.time(), - worker_id=self.worker_id, - ) - self.state.put_updated_trial(trial_to_eval) n_failed_set_trial_state = 0 except VersionMismatchError: n_failed_set_trial_state += 1 @@ -512,11 +530,12 @@ def run(self) -> None: # noqa: C901, PLR0915, PLR0912 # We do not retry this, as if some other worker has # managed to manipulate this trial in the meantime, # then something has gone wrong - self.state.report_trial_evaluation( - trial=evaluated_trial, - report=report, - worker_id=self.worker_id, - ) + with self.state.lock_trials(): + self.state._report_trial_evaluation( + trial=evaluated_trial, + report=report, + worker_id=self.worker_id, + ) 
logger.debug("Config %s: %s", evaluated_trial.id, evaluated_trial.config) logger.debug("Loss %s: %s", evaluated_trial.id, report.loss) @@ -553,8 +572,9 @@ def _launch_runtime( # noqa: PLR0913 for _retry_count in range(MAX_RETRIES_CREATE_LOAD_STATE): try: - neps_state = create_or_load_filebased_neps_state( - directory=optimization_dir, + neps_state = NePSState.create_or_load( + path=optimization_dir, + load_only=False, optimizer_info=OptimizerInfo(optimizer_info), optimizer_state=OptimizationState( budget=( @@ -613,7 +633,17 @@ def _launch_runtime( # noqa: PLR0913 # it's not directly advertised as a parameter/env variable or otherwise. import portalocker.portalocker as portalocker_lock_module - setattr(portalocker_lock_module, "LOCKER", LINUX_FILELOCK_FUNCTION) + try: + import fcntl + + if LINUX_FILELOCK_FUNCTION.lower() == "flock": + setattr(portalocker_lock_module, "LOCKER", fcntl.flock) + elif LINUX_FILELOCK_FUNCTION.lower() == "lockf": + setattr(portalocker_lock_module, "LOCKER", fcntl.lockf) + else: + pass + except ImportError: + pass worker = DefaultWorker.new( state=neps_state, diff --git a/neps/state/__init__.py b/neps/state/__init__.py index e870d656..6b190afb 100644 --- a/neps/state/__init__.py +++ b/neps/state/__init__.py @@ -1,23 +1,11 @@ from neps.state.optimizer import BudgetInfo, OptimizationState, OptimizerInfo -from neps.state.protocols import ( - Locker, - ReaderWriter, - Synced, - VersionedResource, - Versioner, -) from neps.state.seed_snapshot import SeedSnapshot from neps.state.trial import Trial __all__ = [ - "Locker", "SeedSnapshot", - "Synced", "BudgetInfo", "OptimizationState", "OptimizerInfo", "Trial", - "ReaderWriter", - "Versioner", - "VersionedResource", ] diff --git a/neps/state/filebased.py b/neps/state/filebased.py index 9394d0e1..3658d511 100644 --- a/neps/state/filebased.py +++ b/neps/state/filebased.py @@ -1,55 +1,22 @@ -"""This module houses the implementation of a NePSState that -does everything on the filesystem, i.e. 
locking, versioning and -storing/loading. - -The main components are: -* [`FileVersioner`][neps.state.filebased.FileVersioner]: A versioner that - stores a version tag on disk, usually for a resource like a Trial. -* [`FileLocker`][neps.state.filebased.FileLocker]: A locker that uses a file - to lock between processes. -* [`TrialRepoInDirectory`][neps.state.filebased.TrialRepoInDirectory]: A - repository of Trials that are stored in a directory. -* `ReaderWriterXXX`: Reader/writers for various resources NePSState needs -* [`load_filebased_neps_state`][neps.state.filebased.load_filebased_neps_state]: - A function to load a NePSState from a directory. -* [`create_filebased_neps_state`][neps.state.filebased.create_filebased_neps_state]: - A function to create a new NePSState in a directory. -""" - from __future__ import annotations import json import logging import pprint -from collections.abc import Iterable, Iterator +from collections.abc import Iterator from contextlib import contextmanager -from dataclasses import asdict, dataclass, field +from dataclasses import asdict, dataclass from pathlib import Path from typing import ClassVar, Final, TypeVar -from typing_extensions import override -from uuid import uuid4 import numpy as np import portalocker as pl from neps.env import ( ENV_VARS_USED, - GLOBAL_ERR_FILELOCK_POLL, - GLOBAL_ERR_FILELOCK_TIMEOUT, - OPTIMIZER_INFO_FILELOCK_POLL, - OPTIMIZER_INFO_FILELOCK_TIMEOUT, - OPTIMIZER_STATE_FILELOCK_POLL, - OPTIMIZER_STATE_FILELOCK_TIMEOUT, - SEED_SNAPSHOT_FILELOCK_POLL, - SEED_SNAPSHOT_FILELOCK_TIMEOUT, - TRIAL_FILELOCK_POLL, - TRIAL_FILELOCK_TIMEOUT, ) -from neps.exceptions import NePSError from neps.state.err_dump import ErrDump -from neps.state.neps_state import NePSState from neps.state.optimizer import BudgetInfo, OptimizationState, OptimizerInfo -from neps.state.protocols import Locker, ReaderWriter, Synced, TrialRepo, Versioner from neps.state.seed_snapshot import SeedSnapshot from neps.state.trial import Trial 
from neps.utils.files import deserialize, serialize @@ -59,43 +26,16 @@ T = TypeVar("T") -def make_sha() -> str: - """Generate a str hex sha.""" - return uuid4().hex - - @dataclass -class FileVersioner(Versioner): - """A versioner that stores a version tag on disk.""" - - version_file: Path - - @override - def current(self) -> str | None: - if not self.version_file.exists(): - return None - return self.version_file.read_text() - - @override - def bump(self) -> str: - sha = make_sha() - self.version_file.write_text(sha) - return sha - - -@dataclass -class ReaderWriterTrial(ReaderWriter[Trial, Path]): +class ReaderWriterTrial: """ReaderWriter for Trial objects.""" - CHEAP_LOCKLESS_READ: ClassVar = True - CONFIG_FILENAME = "config.yaml" METADATA_FILENAME = "metadata.yaml" STATE_FILENAME = "state.txt" REPORT_FILENAME = "report.yaml" PREVIOUS_TRIAL_ID_FILENAME = "previous_trial_id.txt" - @override @classmethod def read(cls, directory: Path) -> Trial: config_path = directory / cls.CONFIG_FILENAME @@ -112,7 +52,6 @@ def read(cls, directory: Path) -> Trial: ), ) - @override @classmethod def write(cls, trial: Trial, directory: Path) -> None: config_path = directory / cls.CONFIG_FILENAME @@ -132,166 +71,13 @@ def write(cls, trial: Trial, directory: Path) -> None: serialize(asdict(trial.report), report_path) -_StaticReaderWriterTrial: Final = ReaderWriterTrial() - -CONFIG_PREFIX_LEN: Final = len("config_") +TrialReaderWriter: Final = ReaderWriterTrial() @dataclass -class TrialRepoInDirectory(TrialRepo[Path]): - """A repository of Trials that are stored in a directory.""" - - directory: Path - _cache: dict[str, Synced[Trial, Path]] = field(default_factory=dict) - - @override - def all_trial_ids(self) -> list[str]: - """List all the trial ids in this trial Repo.""" - return [ - config_path.name[CONFIG_PREFIX_LEN:] - for config_path in self.directory.iterdir() - if config_path.name.startswith("config_") and config_path.is_dir() - ] - - @override - def get_by_id( - self, - 
trial_id: str, - *, - lock_poll: float = TRIAL_FILELOCK_POLL, - lock_timeout: float | None = TRIAL_FILELOCK_TIMEOUT, - ) -> Synced[Trial, Path]: - """Get a Trial by its ID. - - !!! note - - This will **not** explicitly sync the trial and it is up to the caller - to do so. Most of the time, the caller should be a NePSState - object which will do that for you. However if the trial is not in the - cache, then it will be loaded from disk which requires syncing. - - Args: - trial_id: The ID of the trial to get. - lock_poll: The poll time for the file lock. - lock_timeout: The timeout for the file lock. - - Returns: - The trial with the given ID. - """ - trial = self._cache.get(trial_id) - if trial is not None: - return trial - - config_path = self.directory / f"config_{trial_id}" - if not config_path.exists(): - raise TrialRepo.TrialNotFoundError(trial_id, config_path) - - trial = Synced.load( - location=config_path, - locker=FileLocker( - lock_path=config_path / ".lock", - poll=lock_poll, - timeout=lock_timeout, - ), - versioner=FileVersioner(version_file=config_path / ".version"), - reader_writer=_StaticReaderWriterTrial, - ) - self._cache[trial_id] = trial - return trial - - @override - def put_new( - self, - trial: Trial, - *, - lock_poll: float = TRIAL_FILELOCK_POLL, - lock_timeout: float | None = TRIAL_FILELOCK_TIMEOUT, - ) -> Synced[Trial, Path]: - """Put a new Trial into the repository. - - Args: - trial: The trial to put. - lock_poll: The poll time for the file lock. - lock_timeout: The timeout for the file lock. - - Returns: - The synced trial. - - Raises: - TrialRepo.TrialAlreadyExistsError: If the trial already exists in the - repository. - """ - config_path = self.directory.absolute().resolve() / f"config_{trial.metadata.id}" - if config_path.exists(): - # This shouldn't exist, we load in the trial to see the current state of it - # to try determine wtf is going on for logging purposes. 
- try: - shared_trial = Synced.load( - location=config_path, - locker=FileLocker( - lock_path=config_path / ".lock", - poll=lock_poll, - timeout=lock_timeout, - ), - versioner=FileVersioner(version_file=config_path / ".version"), - reader_writer=_StaticReaderWriterTrial, - ) - already_existing_trial = shared_trial._unsynced() - extra_msg = ( - f"The existing trial is the following: {already_existing_trial}" - ) - except Exception: # noqa: BLE001 - extra_msg = "Failed to load the existing trial to provide more info." - - raise TrialRepo.TrialAlreadyExistsError( - f"Trial '{trial.metadata.id}' already exists as '{config_path}'." - f" Tried to put in the trial: {trial}." - f"\n{extra_msg}" - ) - - # HACK: We do this here as there is no way to know where a Trial will - # be located when it's created... - trial.metadata.location = str(config_path) - shared_trial = Synced.new( - data=trial, - location=config_path, - locker=FileLocker( - lock_path=config_path / ".lock", - poll=lock_poll, - timeout=lock_timeout, - ), - versioner=FileVersioner(version_file=config_path / ".version"), - reader_writer=_StaticReaderWriterTrial, - ) - self._cache[trial.metadata.id] = shared_trial - return shared_trial - - @override - def all(self) -> dict[str, Synced[Trial, Path]]: - """Get a dictionary of all the Trials in the repository. - - !!! note - See [`get_by_id()`][neps.state.filebased.TrialRepoInDirectory.get_by_id] - for notes on the trials syncing. 
- """ - return {trial_id: self.get_by_id(trial_id) for trial_id in self.all_trial_ids()} - - @override - def pending(self) -> Iterable[tuple[str, Trial]]: - pending = [ - (_id, trial, trial.metadata.time_sampled) - for (_id, t) in self.all().items() - if (trial := t.synced()).state == Trial.State.PENDING - ] - return iter((_id, t) for _id, t, _ in sorted(pending, key=lambda x: x[2])) - - -@dataclass -class ReaderWriterSeedSnapshot(ReaderWriter[SeedSnapshot, Path]): +class ReaderWriterSeedSnapshot: """ReaderWriter for SeedSnapshot objects.""" - CHEAP_LOCKLESS_READ: ClassVar = True - # It seems like they're all uint32 but I can't be sure. PY_RNG_STATE_DTYPE: ClassVar = np.int64 @@ -301,7 +87,6 @@ class ReaderWriterSeedSnapshot(ReaderWriter[SeedSnapshot, Path]): TORCH_CUDA_RNG_STATE_FILENAME: ClassVar = "torch_cuda_rng_state.pt" SEED_INFO_FILENAME: ClassVar = "seed_info.json" - @override @classmethod def read(cls, directory: Path) -> SeedSnapshot: seedinfo_path = directory / cls.SEED_INFO_FILENAME @@ -350,7 +135,6 @@ def read(cls, directory: Path) -> SeedSnapshot: torch_cuda_rng=torch_cuda_rng, ) - @override @classmethod def write(cls, snapshot: SeedSnapshot, directory: Path) -> None: seedinfo_path = directory / cls.SEED_INFO_FILENAME @@ -387,20 +171,16 @@ def write(cls, snapshot: SeedSnapshot, directory: Path) -> None: @dataclass -class ReaderWriterOptimizerInfo(ReaderWriter[OptimizerInfo, Path]): +class ReaderWriterOptimizerInfo: """ReaderWriter for OptimizerInfo objects.""" - CHEAP_LOCKLESS_READ: ClassVar = True - INFO_FILENAME: ClassVar = "info.yaml" - @override @classmethod def read(cls, directory: Path) -> OptimizerInfo: info_path = directory / cls.INFO_FILENAME return OptimizerInfo(info=deserialize(info_path)) - @override @classmethod def write(cls, optimizer_info: OptimizerInfo, directory: Path) -> None: info_path = directory / cls.INFO_FILENAME @@ -412,14 +192,11 @@ def write(cls, optimizer_info: OptimizerInfo, directory: Path) -> None: # handle this. 
# TODO(eddiebergman): May also want to consider serializing budget into a seperate entity @dataclass -class ReaderWriterOptimizationState(ReaderWriter[OptimizationState, Path]): +class ReaderWriterOptimizationState: """ReaderWriter for OptimizationState objects.""" - CHEAP_LOCKLESS_READ: ClassVar = True - STATE_FILE_NAME: ClassVar = "state.yaml" - @override @classmethod def read(cls, directory: Path) -> OptimizationState: state_path = directory / cls.STATE_FILE_NAME @@ -431,7 +208,6 @@ def read(cls, directory: Path) -> OptimizationState: budget=budget, ) - @override @classmethod def write(cls, info: OptimizationState, directory: Path) -> None: info_path = directory / cls.STATE_FILE_NAME @@ -439,24 +215,20 @@ def write(cls, info: OptimizationState, directory: Path) -> None: @dataclass -class ReaderWriterErrDump(ReaderWriter[ErrDump, Path]): +class ReaderWriterErrDump: """ReaderWriter for shared error lists.""" - CHEAP_LOCKLESS_READ: ClassVar = True - - name: str - - @override - def read(self, directory: Path) -> ErrDump: - errors_path = directory / f"{self.name}-errors.jsonl" + @classmethod + def read(cls, directory: Path) -> ErrDump: + errors_path = directory / "errors.jsonl" with errors_path.open("r") as f: data = [json.loads(line) for line in f] return ErrDump([ErrDump.SerializableTrialError(**d) for d in data]) - @override - def write(self, err_dump: ErrDump, directory: Path) -> None: - errors_path = directory / f"{self.name}-errors.jsonl" + @classmethod + def write(cls, err_dump: ErrDump, directory: Path) -> None: + errors_path = directory / "errors.jsonl" with errors_path.open("w") as f: lines = [json.dumps(asdict(trial_err)) for trial_err in err_dump.errs] f.write("\n".join(lines)) @@ -466,7 +238,7 @@ def write(self, err_dump: ErrDump, directory: Path) -> None: @dataclass -class FileLocker(Locker): +class FileLocker: """File-based locker using `portalocker`. 
[`FileLocker`][neps.state.locker.file.FileLocker] implements @@ -482,18 +254,6 @@ class FileLocker(Locker): def __post_init__(self) -> None: self.lock_path = self.lock_path.resolve().absolute() - @override - def is_locked(self) -> bool: - if not self.lock_path.exists(): - return False - try: - with self.lock(fail_if_locked=True): - pass - return False - except pl.exceptions.LockException: - return True - - @override @contextmanager def lock( self, @@ -521,187 +281,3 @@ def lock( " environment variables to increase the timeout:" f"\n\n{pprint.pformat(ENV_VARS_USED)}" ) from e - - -def load_filebased_neps_state(directory: Path) -> NePSState[Path]: - """Load a NePSState from a directory. - - Args: - directory: The directory to load the state from. - - Returns: - The loaded NePSState. - - Raises: - FileNotFoundError: If no NePSState is found at the given directory. - """ - if not directory.exists(): - raise FileNotFoundError(f"No NePSState found at '{directory}'.") - directory.mkdir(parents=True, exist_ok=True) - config_dir = directory / "configs" - config_dir.mkdir(parents=True, exist_ok=True) - seed_dir = directory / ".seed_state" - seed_dir.mkdir(parents=True, exist_ok=True) - error_dir = directory / ".errors" - error_dir.mkdir(parents=True, exist_ok=True) - optimizer_state_dir = directory / ".optimizer_state" - optimizer_state_dir.mkdir(parents=True, exist_ok=True) - optimizer_info_dir = directory / ".optimizer_info" - optimizer_info_dir.mkdir(parents=True, exist_ok=True) - - return NePSState( - location=str(directory.absolute().resolve()), - _trials=TrialRepoInDirectory(config_dir), - _optimizer_info=Synced.load( - location=optimizer_info_dir, - versioner=FileVersioner(version_file=optimizer_info_dir / ".version"), - locker=FileLocker( - lock_path=optimizer_info_dir / ".lock", - poll=OPTIMIZER_INFO_FILELOCK_POLL, - timeout=OPTIMIZER_INFO_FILELOCK_TIMEOUT, - ), - reader_writer=ReaderWriterOptimizerInfo(), - ), - _seed_state=Synced.load( - location=seed_dir, - 
reader_writer=ReaderWriterSeedSnapshot(), - versioner=FileVersioner(version_file=seed_dir / ".version"), - locker=FileLocker( - lock_path=seed_dir / ".lock", - poll=SEED_SNAPSHOT_FILELOCK_POLL, - timeout=SEED_SNAPSHOT_FILELOCK_TIMEOUT, - ), - ), - _shared_errors=Synced.load( - location=error_dir, - reader_writer=ReaderWriterErrDump("all"), - versioner=FileVersioner(version_file=error_dir / ".all.version"), - locker=FileLocker( - lock_path=error_dir / ".all.lock", - poll=GLOBAL_ERR_FILELOCK_POLL, - timeout=GLOBAL_ERR_FILELOCK_TIMEOUT, - ), - ), - _optimizer_state=Synced.load( - location=optimizer_state_dir, - reader_writer=ReaderWriterOptimizationState(), - versioner=FileVersioner(version_file=optimizer_state_dir / ".version"), - locker=FileLocker( - lock_path=optimizer_state_dir / ".lock", - poll=OPTIMIZER_STATE_FILELOCK_POLL, - timeout=OPTIMIZER_STATE_FILELOCK_TIMEOUT, - ), - ), - ) - - -def create_or_load_filebased_neps_state( - directory: Path, - *, - optimizer_info: OptimizerInfo, - optimizer_state: OptimizationState, -) -> NePSState[Path]: - """Create a new NePSState in a directory or load the existing one - if it already exists. - - !!! warning - - We check that the optimizer info in the NePSState on disk matches - the one that is passed. However we do not lock this check so it - is possible that if two processes try to create a NePSState at the - same time, both with different optimizer infos, that one will fail - to create the NePSState. This is a limitation of the current design. - - In principal, we could allow multiple optimizers to be run and share - the same set of trials. - - Args: - directory: The directory to create the state in. - optimizer_info: The optimizer info to use. - optimizer_state: The optimizer state to use. - - Returns: - The NePSState. - - Raises: - NePSError: If the optimizer info on disk does not match the one provided. 
- """ - is_new = not directory.exists() - directory.mkdir(parents=True, exist_ok=True) - config_dir = directory / "configs" - config_dir.mkdir(parents=True, exist_ok=True) - seed_dir = directory / ".seed_state" - seed_dir.mkdir(parents=True, exist_ok=True) - error_dir = directory / ".errors" - error_dir.mkdir(parents=True, exist_ok=True) - optimizer_state_dir = directory / ".optimizer_state" - optimizer_state_dir.mkdir(parents=True, exist_ok=True) - optimizer_info_dir = directory / ".optimizer_info" - optimizer_info_dir.mkdir(parents=True, exist_ok=True) - - # We have to do one bit of sanity checking to ensure that the optimzier - # info on disk manages the one we have recieved, otherwise we are unsure which - # optimizer is being used. - # NOTE: We assume that we do not have to worry about a race condition - # here where we have two different NePSState objects with two different optimizer - # infos trying to be created at the same time. This avoids the need to lock to - # check the optimizer info. If this assumption changes, then we would have - # to first lock before we do this check - optimizer_info_reader_writer = ReaderWriterOptimizerInfo() - if not is_new: - existing_info = optimizer_info_reader_writer.read(optimizer_info_dir) - if existing_info != optimizer_info: - raise NePSError( - "The optimizer info on disk does not match the one provided." - f"\nOn disk: {existing_info}\nProvided: {optimizer_info}" - f"\n\nLoaded the one on disk from {optimizer_info_dir}." 
- ) - - return NePSState( - location=str(directory.absolute().resolve()), - _trials=TrialRepoInDirectory(config_dir), - _optimizer_info=Synced.new_or_load( - data=optimizer_info, # type: ignore - location=optimizer_info_dir, - versioner=FileVersioner(version_file=optimizer_info_dir / ".version"), - locker=FileLocker( - lock_path=optimizer_info_dir / ".lock", - poll=OPTIMIZER_INFO_FILELOCK_POLL, - timeout=OPTIMIZER_INFO_FILELOCK_TIMEOUT, - ), - reader_writer=ReaderWriterOptimizerInfo(), - ), - _seed_state=Synced.new_or_load( - data=SeedSnapshot.new_capture(), - location=seed_dir, - reader_writer=ReaderWriterSeedSnapshot(), - versioner=FileVersioner(version_file=seed_dir / ".version"), - locker=FileLocker( - lock_path=seed_dir / ".lock", - poll=SEED_SNAPSHOT_FILELOCK_POLL, - timeout=SEED_SNAPSHOT_FILELOCK_TIMEOUT, - ), - ), - _shared_errors=Synced.new_or_load( - data=ErrDump(), - location=error_dir, - reader_writer=ReaderWriterErrDump("all"), - versioner=FileVersioner(version_file=error_dir / ".all.version"), - locker=FileLocker( - lock_path=error_dir / ".all.lock", - poll=GLOBAL_ERR_FILELOCK_POLL, - timeout=GLOBAL_ERR_FILELOCK_TIMEOUT, - ), - ), - _optimizer_state=Synced.new_or_load( - data=optimizer_state, - location=optimizer_state_dir, - reader_writer=ReaderWriterOptimizationState(), - versioner=FileVersioner(version_file=optimizer_state_dir / ".version"), - locker=FileLocker( - lock_path=optimizer_state_dir / ".lock", - poll=OPTIMIZER_STATE_FILELOCK_POLL, - timeout=OPTIMIZER_STATE_FILELOCK_TIMEOUT, - ), - ), - ) diff --git a/neps/state/neps_state.py b/neps/state/neps_state.py index 3d4f3186..39dc5265 100644 --- a/neps/state/neps_state.py +++ b/neps/state/neps_state.py @@ -11,70 +11,267 @@ from __future__ import annotations import logging +import pickle import time -from collections.abc import Callable +from collections.abc import Callable, Iterator +from contextlib import contextmanager from dataclasses import dataclass, field -from typing import TYPE_CHECKING, 
Generic, TypeVar, overload - -from more_itertools import take - -from neps.exceptions import TrialAlreadyExistsError +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Generic, + Literal, + TypeAlias, + TypeVar, + overload, +) +from uuid import uuid4 + +from neps.env import ( + STATE_FILELOCK_POLL, + STATE_FILELOCK_TIMEOUT, + TRIAL_FILELOCK_POLL, + TRIAL_FILELOCK_TIMEOUT, +) +from neps.exceptions import NePSError, TrialAlreadyExistsError, TrialNotFoundError from neps.state.err_dump import ErrDump +from neps.state.filebased import ( + FileLocker, + ReaderWriterErrDump, + ReaderWriterOptimizationState, + ReaderWriterOptimizerInfo, + ReaderWriterSeedSnapshot, + TrialReaderWriter, +) from neps.state.optimizer import OptimizationState, OptimizerInfo -from neps.state.trial import Trial +from neps.state.seed_snapshot import SeedSnapshot +from neps.state.trial import Report, Trial if TYPE_CHECKING: from neps.optimizers.base_optimizer import BaseOptimizer - from neps.state.protocols import Synced, TrialRepo - from neps.state.seed_snapshot import SeedSnapshot logger = logging.getLogger(__name__) +N_UNSAFE_RETRIES = 10 # TODO: Technically we don't need the same Location type for all shared objects. 
Loc = TypeVar("Loc") T = TypeVar("T") +Version: TypeAlias = str + +Resource: TypeAlias = Literal[ + "optimizer_info", "optimizer_state", "seed_state", "errors", "configs" +] + + +def make_sha() -> Version: + """Generate a str hex sha.""" + return uuid4().hex + + +CONFIG_PREFIX_LEN = len("config_") + + +# TODO: Ergonomics of this class sucks +@dataclass +class TrialRepo: + directory: Path + version_file: Path + cache: dict[str, tuple[Trial, Version]] = field(default_factory=dict) + + def __post_init__(self) -> None: + self.directory.mkdir(parents=True, exist_ok=True) + + def list_trial_ids(self) -> list[str]: + return [ + config_path.name[CONFIG_PREFIX_LEN:] + for config_path in self.directory.iterdir() + if config_path.name.startswith("config_") and config_path.is_dir() + ] + + def latest(self) -> dict[str, Trial]: + if not self.version_file.exists(): + return {} + + with self.version_file.open("rb") as f: + versions_on_disk = pickle.load(f) # noqa: S301 + + stale = { + k: v + for k, v in versions_on_disk.items() + if self.cache.get(k, (None, "__not_found__")) != v + } + for trial_id, disk_version in stale.items(): + loaded_trial = self.load_trial_from_disk(trial_id) + self.cache[trial_id] = (loaded_trial, disk_version) + + return {k: v[0] for k, v in self.cache.items()} + + def new_trial(self, trial: Trial) -> None: + config_path = self.directory / f"config_{trial.id}" + if config_path.exists(): + raise TrialAlreadyExistsError(trial.id, config_path) + config_path.mkdir(parents=True, exist_ok=True) + self.update_trial(trial) + + def update_trial(self, trial: Trial) -> None: + new_version = make_sha() + TrialReaderWriter.write(trial, self.directory / f"config_{trial.id}") + self.cache[trial.id] = (trial, new_version) + + def write_version_file(self) -> None: + with self.version_file.open("wb") as f: + pickle.dump({k: v[1] for k, v in self.cache.items()}, f) + + def trials_in_memory(self) -> dict[str, Trial]: + return {k: v[0] for k, v in self.cache.items()} + + def 
load_trial_from_disk(self, trial_id: str) -> Trial: + config_path = self.directory / f"config_{trial_id}" + if not config_path.exists(): + raise TrialNotFoundError(trial_id, config_path) + + return TrialReaderWriter.read(config_path) + @dataclass -class NePSState(Generic[Loc]): +class VersionedResource(Generic[T]): + resource: T + path: Path + read: Callable[[Path], T] + write: Callable[[T, Path], None] + version_file: Path + version: Version = "__not_yet_written__" + + def latest(self) -> T: + if not self.version_file.exists(): + return self.resource + + file_version = self.version_file.read_text() + if self.version == file_version: + return self.resource + + self.resource = self.read(self.path) + self.version = file_version + return self.resource + + def update(self, new_resource: T) -> Version: + self.resource = new_resource + self.version = make_sha() + self.version_file.write_text(self.version) + self.write(new_resource, self.path) + return self.version + + @classmethod + def new( + cls, + resource: T, + path: Path, + read: Callable[[Path], T], + write: Callable[[T, Path], None], + version_file: Path, + ) -> VersionedResource[T]: + if version_file.exists(): + raise FileExistsError(f"Version file already exists at '{version_file}'.") + + write(resource, path) + version = make_sha() + version_file.write_text(version) + return cls( + resource=resource, + path=path, + read=read, + write=write, + version_file=version_file, + version=version, + ) + + @classmethod + def load( + cls, + path: Path, + *, + read: Callable[[Path], T], + write: Callable[[T, Path], None], + version_file: Path, + ) -> VersionedResource[T]: + if not path.exists(): + raise FileNotFoundError(f"Resource not found at '{path}'.") + + return cls( + resource=read(path), + path=path, + read=read, + write=write, + version_file=version_file, + version=version_file.read_text(), + ) + + +@dataclass +class NePSState: """The main state object that holds all the shared state objects.""" - location: str + 
path: Path - _trials: TrialRepo[Loc] = field(repr=False) - _optimizer_info: Synced[OptimizerInfo, Loc] - _seed_state: Synced[SeedSnapshot, Loc] = field(repr=False) - _optimizer_state: Synced[OptimizationState, Loc] - _shared_errors: Synced[ErrDump, Loc] = field(repr=False) + _trial_lock: FileLocker = field(repr=False) + _trials: TrialRepo = field(repr=False) - def put_updated_trial(self, trial: Trial, /) -> None: - """Update the trial with the new information. + _state_lock: FileLocker = field(repr=False) + _optimizer_info: VersionedResource[OptimizerInfo] = field(repr=False) + _seed_snapshot: VersionedResource[SeedSnapshot] = field(repr=False) + _optimizer_state: VersionedResource[OptimizationState] = field(repr=False) - Args: - trial: The trial to update. + _err_lock: FileLocker = field(repr=False) + _shared_errors: VersionedResource[ErrDump] = field(repr=False) - Raises: - VersionMismatchError: If the trial has been updated since it was last - fetched by the worker using this state. This indicates that some other - worker has updated the trial in the meantime and the changes from - this worker are rejected. 
- """ - shared_trial = self._trials.get_by_id(trial.id) - shared_trial.put(trial) + @contextmanager + def lock_for_sampling(self) -> Iterator[None]: + """Acquire the state lock and trials lock.""" + with self._state_lock.lock(), self._trial_lock.lock(): + yield - def get_trial_by_id(self, trial_id: str, /) -> Trial: - """Get a trial by its id.""" - return self._trials.get_by_id(trial_id).synced() + @contextmanager + def lock_trials(self) -> Iterator[None]: + """Acquire the state lock.""" + with self._trial_lock.lock(): + yield + + def lock_and_read_trials(self) -> dict[str, Trial]: + """Acquire the state lock and read the trials.""" + with self._trial_lock.lock(): + return self._trials.latest() - def sample_trial( + def lock_and_sample_trial(self, optimizer: BaseOptimizer, *, worker_id: str) -> Trial: + """Acquire the state lock and sample a trial.""" + with self.lock_for_sampling(): + return self._sample_trial(optimizer, worker_id=worker_id) + + def lock_and_report_trial_evaluation( + self, + trial: Trial, + report: Report, + *, + worker_id: str, + ) -> None: + """Acquire the state lock and report the trial evaluation.""" + with self._trial_lock.lock(), self._err_lock.lock(): + self._report_trial_evaluation(trial, report, worker_id=worker_id) + + def _sample_trial( self, optimizer: BaseOptimizer, *, worker_id: str, _sample_hooks: list[Callable] | None = None, + _trials: dict[str, Trial] | None = None, ) -> Trial: """Sample a new trial from the optimizer. + !!! warning + + Responsibility of locking is on caller. + Args: optimizer: The optimizer to sample the trial from. worker_id: The worker that is sampling the trial. @@ -83,110 +280,100 @@ def sample_trial( Returns: The new trial. 
""" - with ( - self._optimizer_state.acquire() as (opt_state, put_opt), - self._seed_state.acquire() as (seed_state, put_seed_state), - ): - # NOTE: We make the assumption that as we have acquired the optimizer - # state, there is not possibility of another trial being created between - # the time we read in the trials below and `ask()`ing for the next trials - # from the optimizer. If so, that means there is another source of trial - # generation that occurs outside of this function and outside the scope - # of acquiring the optimizer_state lock. - trials: dict[str, Trial] = { - trial_id: shared_trial.synced() - for trial_id, shared_trial in list(self._trials.all().items()) - } - - seed_state.set_as_global_seed_state() - - # TODO: Not sure if any existing pre_load hooks required - # it to be done after `load_results`... I hope not. - if _sample_hooks is not None: - for hook in _sample_hooks: - optimizer = hook(optimizer) - - # NOTE: Re-work this, as the part's that are recomputed - # do not need to be serialized - budget = opt_state.budget - if budget is not None: - budget = budget.clone() - - # NOTE: All other values of budget are ones that should remain - # constant, there are currently only these two which are dynamic as - # optimization unfold - budget.used_cost_budget = sum( - trial.report.cost - for trial in trials.values() - if trial.report is not None and trial.report.cost is not None - ) - budget.used_evaluations = len(trials) - - sampled_config_maybe_new_opt_state = optimizer.ask( - trials=trials, - budget_info=budget, + trials = self._trials.latest() if _trials is None else _trials + seed_state = self._seed_snapshot.latest() + opt_state = self._optimizer_state.latest() + + seed_state.set_as_global_seed_state() + + # TODO: Not sure if any existing pre_load hooks required + # it to be done after `load_results`... I hope not. 
+ if _sample_hooks is not None: + for hook in _sample_hooks: + optimizer = hook(optimizer) + + # NOTE: Re-work this, as the part's that are recomputed + # do not need to be serialized + budget = opt_state.budget + if budget is not None: + budget = budget.clone() + + # NOTE: All other values of budget are ones that should remain + # constant, there are currently only these two which are dynamic as + # optimization unfold + budget.used_cost_budget = sum( + trial.report.cost + for trial in trials.values() + if trial.report is not None and trial.report.cost is not None ) - - if isinstance(sampled_config_maybe_new_opt_state, tuple): - sampled_config, new_opt_state = sampled_config_maybe_new_opt_state - else: - sampled_config = sampled_config_maybe_new_opt_state - new_opt_state = opt_state.shared_state - - if sampled_config.previous_config_id is not None: - previous_trial = trials.get(sampled_config.previous_config_id) - if previous_trial is None: - raise ValueError( - f"Previous trial '{sampled_config.previous_config_id}' not found." - ) - previous_trial_location = previous_trial.metadata.location + budget.used_evaluations = len(trials) + + sampled_config_maybe_new_opt_state = optimizer.ask( + trials=trials, + budget_info=budget, + ) + + if isinstance(sampled_config_maybe_new_opt_state, tuple): + sampled_config, new_opt_state = sampled_config_maybe_new_opt_state + else: + sampled_config = sampled_config_maybe_new_opt_state + new_opt_state = opt_state.shared_state + + if sampled_config.previous_config_id is not None: + previous_trial = trials.get(sampled_config.previous_config_id) + if previous_trial is None: + raise ValueError( + f"Previous trial '{sampled_config.previous_config_id}' not found." 
+ ) + previous_trial_location = previous_trial.metadata.location + else: + previous_trial_location = None + + trial = Trial.new( + trial_id=sampled_config.id, + location="", # HACK: This will be set by the `TrialRepo` in `put_new` + config=sampled_config.config, + previous_trial=sampled_config.previous_config_id, + previous_trial_location=previous_trial_location, + time_sampled=time.time(), + worker_id=worker_id, + ) + try: + self._trials.new_trial(trial) + self._trials.write_version_file() + except TrialAlreadyExistsError as e: + if sampled_config.id in trials: + logger.warning( + "The new sampled trial was given an id of '%s', yet this already" + " exists in the loaded in trials given to the optimizer. This" + " indicates a bug with the optimizers allocation of ids.", + sampled_config.id, + ) else: - previous_trial_location = None - - trial = Trial.new( - trial_id=sampled_config.id, - location="", # HACK: This will be set by the `TrialRepo` in `put_new` - config=sampled_config.config, - previous_trial=sampled_config.previous_config_id, - previous_trial_location=previous_trial_location, - time_sampled=time.time(), - worker_id=worker_id, - ) - try: - self._trials.put_new(trial) - except TrialAlreadyExistsError as e: - if sampled_config.id in trials: - logger.warning( - "The new sampled trial was given an id of '%s', yet this already" - " exists in the loaded in trials given to the optimizer. This" - " indicates a bug with the optimizers allocation of ids.", - sampled_config.id, - ) - else: - logger.warning( - "The new sampled trial was given an id of '%s', which is not one" - " that was loaded in by the optimizer. 
This indicates that" - " configuration '%s' was put on disk during the time that this" - " worker had the optimizer state lock OR that after obtaining the" - " optimizer state lock, somehow this configuration failed to be" - " loaded in and passed to the optimizer.", - sampled_config.id, - sampled_config.id, - ) - raise e + logger.warning( + "The new sampled trial was given an id of '%s', which is not one" + " that was loaded in by the optimizer. This indicates that" + " configuration '%s' was put on disk during the time that this" + " worker had the optimizer state lock OR that after obtaining the" + " optimizer state lock, somehow this configuration failed to be" + " loaded in and passed to the optimizer.", + sampled_config.id, + sampled_config.id, + ) + raise e - seed_state.recapture() - put_seed_state(seed_state) - put_opt( - OptimizationState(budget=opt_state.budget, shared_state=new_opt_state) - ) + seed_state.recapture() + self._seed_snapshot.update(seed_state) + self._optimizer_state.update( + OptimizationState(budget=opt_state.budget, shared_state=new_opt_state) + ) return trial - def report_trial_evaluation( + def _report_trial_evaluation( self, trial: Trial, - report: Trial.Report, + report: Report, *, worker_id: str, ) -> None: @@ -199,61 +386,248 @@ def report_trial_evaluation( optimizer: The optimizer to update and get the state from worker_id: The worker that evaluated the trial. """ - shared_trial = self._trials.get_by_id(trial.id) - # TODO: This would fail if some other worker has already updated the trial. - # IMPORTANT: We need to attach the report to the trial before updating the things. 
trial.report = report - shared_trial.put(trial) + self._trials.update_trial(trial) + self._trials.write_version_file() + logger.debug("Updated trial '%s' with status '%s'", trial.id, trial.state) if report.err is not None: - with self._shared_errors.acquire() as (errs, put_errs): - trial_err = ErrDump.SerializableTrialError( - trial_id=trial.id, - worker_id=worker_id, - err_type=type(report.err).__name__, - err=str(report.err), - tb=report.tb, + with self._err_lock.lock(): + err_dump = self._shared_errors.latest() + err_dump.errs.append( + ErrDump.SerializableTrialError( + trial_id=trial.id, + worker_id=worker_id, + err_type=type(report.err).__name__, + err=str(report.err), + tb=report.tb, + ) ) - errs.append(trial_err) - put_errs(errs) + self._shared_errors.update(err_dump) + + def all_trial_ids(self) -> list[str]: + """Get all the trial ids.""" + return self._trials.list_trial_ids() - def get_errors(self) -> ErrDump: + def lock_and_get_errors(self) -> ErrDump: """Get all the errors that have occurred during the optimization.""" - return self._shared_errors.synced() + with self._err_lock.lock(): + return self._shared_errors.latest() + + def lock_and_get_optimizer_info(self) -> OptimizerInfo: + """Get the optimizer information.""" + with self._state_lock.lock(): + return self._optimizer_info.latest() + + def lock_and_get_optimizer_state(self) -> OptimizationState: + """Get the optimizer state.""" + with self._state_lock.lock(): + return self._optimizer_state.latest() + + def lock_and_get_trial_by_id(self, trial_id: str) -> Trial: + """Get a trial by its id.""" + with self._trial_lock.lock(): + return self._trials.load_trial_from_disk(trial_id) + + def unsafe_retry_get_trial_by_id(self, trial_id: str) -> Trial: + """Get a trial by id but use unsafe retries.""" + for _ in range(N_UNSAFE_RETRIES): + try: + return self._trials.load_trial_from_disk(trial_id) + except TrialNotFoundError as e: + raise e + except Exception as e: # noqa: BLE001 + logger.warning( + "Failed 
to get trial '%s' due to an error: %s", trial_id, e + ) + time.sleep(0.1) + continue + + raise NePSError( + f"Failed to get trial '{trial_id}' after {N_UNSAFE_RETRIES} retries." + ) + + def put_updated_trial(self, trial: Trial) -> None: + """Update the trial.""" + with self._trial_lock.lock(): + self._trials.update_trial(trial) + self._trials.write_version_file() @overload - def get_next_pending_trial(self) -> Trial | None: ... + def lock_and_get_next_pending_trial(self) -> Trial | None: ... @overload - def get_next_pending_trial(self, n: int | None = None) -> list[Trial]: ... + def lock_and_get_next_pending_trial(self, n: int | None = None) -> list[Trial]: ... + + def lock_and_get_next_pending_trial( + self, + n: int | None = None, + ) -> Trial | list[Trial] | None: + """Get the next pending trial.""" + with self._trial_lock.lock(): + trials = self._trials.latest() + pendings = sorted( + [ + trial + for trial in trials.values() + if trial.state == Trial.State.PENDING + ], + key=lambda t: t.metadata.time_sampled, + ) + if n is None: + return pendings[0] if pendings else None + return pendings[:n] + + @classmethod + def create_or_load( + cls, + path: Path, + *, + load_only: bool = False, + optimizer_info: OptimizerInfo | None = None, + optimizer_state: OptimizationState | None = None, + ) -> NePSState: + """Create a new NePSState in a directory or load the existing one + if it already exists, depending on the argument. - def get_next_pending_trial(self, n: int | None = None) -> Trial | list[Trial] | None: - """Get the next pending trial to evaluate. + !!! warning - Args: - n: The number of trials to get. If `None`, get the next trial. + We check that the optimizer info in the NePSState on disk matches + the one that is passed. However we do not lock this check so it + is possible that if two processes try to create a NePSState at the + same time, both with different optimizer infos, that one will fail + to create the NePSState. 
This is a limitation of the current design.

-        Returns:
-            The next trial or a list of trials if `n` is not `None`.
-        """
-        _pending_itr = (shared_trial for _, shared_trial in self._trials.pending())
-        if n is not None:
-            return take(n, _pending_itr)
-        return next(_pending_itr, None)
+            In principle, we could allow multiple optimizers to be run and share
+            the same set of trials.

-    def all_trial_ids(self) -> list[str]:
-        """Get all the trial ids that are known about."""
-        return self._trials.all_trial_ids()
+        Args:
+            path: The directory to create the state in.
+            load_only: If True, only load the state and do not create a new one.
+            optimizer_info: The optimizer info to use.
+            optimizer_state: The optimizer state to use.

-    def get_all_trials(self) -> dict[str, Trial]:
-        """Get all the trials that are known about."""
-        return {_id: trial.synced() for _id, trial in self._trials.all().items()}
+        Returns:
+            The NePSState.

-    def optimizer_info(self) -> OptimizerInfo:
-        """Get the optimizer information."""
-        return self._optimizer_info.synced()
+        Raises:
+            NePSError: If the optimizer info on disk does not match the one provided.
+        """
+        is_new = not path.exists()
+        if load_only:
+            if is_new:
+                raise FileNotFoundError(f"No NePSState found at '{path}'.")
+        else:
+            assert optimizer_info is not None
+            assert optimizer_state is not None
+
+        path.mkdir(parents=True, exist_ok=True)
+        config_dir = path / "configs"
+        config_dir.mkdir(parents=True, exist_ok=True)
+        seed_dir = path / ".seed_state"
+        seed_dir.mkdir(parents=True, exist_ok=True)
+        error_dir = path / ".errors"
+        error_dir.mkdir(parents=True, exist_ok=True)
+        optimizer_state_dir = path / ".optimizer_state"
+        optimizer_state_dir.mkdir(parents=True, exist_ok=True)
+        optimizer_info_dir = path / ".optimizer_info"
+        optimizer_info_dir.mkdir(parents=True, exist_ok=True)
+
+        # We have to do one bit of sanity checking to ensure that the optimizer
+        # info on disk matches the one we have received, otherwise we are unsure which
+        # optimizer is being used.
+        # NOTE: We assume that we do not have to worry about a race condition
+        # here where we have two different NePSState objects with two different optimizer
+        # infos trying to be created at the same time. This avoids the need to lock to
+        # check the optimizer info.
If this assumption changes, then we would have + # to first lock before we do this check + if not is_new: + _optimizer_info = VersionedResource.load( + optimizer_info_dir, + read=ReaderWriterOptimizerInfo.read, + write=ReaderWriterOptimizerInfo.write, + version_file=optimizer_info_dir / ".version", + ) + _optimizer_state = VersionedResource.load( + optimizer_state_dir, + read=ReaderWriterOptimizationState.read, + write=ReaderWriterOptimizationState.write, + version_file=optimizer_state_dir / ".version", + ) + _seed_snapshot = VersionedResource.load( + seed_dir, + read=ReaderWriterSeedSnapshot.read, + write=ReaderWriterSeedSnapshot.write, + version_file=seed_dir / ".version", + ) + _shared_errors = VersionedResource.load( + error_dir, + read=ReaderWriterErrDump.read, + write=ReaderWriterErrDump.write, + version_file=error_dir / ".version", + ) + existing_info = _optimizer_info.latest() + if not load_only and existing_info != optimizer_info: + raise NePSError( + "The optimizer info on disk does not match the one provided." + f"\nOn disk: {existing_info}\nProvided: {optimizer_info}" + f"\n\nLoaded the one on disk from {optimizer_info_dir}." 
+ ) + else: + assert optimizer_info is not None + assert optimizer_state is not None + _optimizer_info = VersionedResource.new( + resource=optimizer_info, + path=optimizer_info_dir, + read=ReaderWriterOptimizerInfo.read, + write=ReaderWriterOptimizerInfo.write, + version_file=optimizer_info_dir / ".version", + ) + _optimizer_state = VersionedResource.new( + resource=optimizer_state, + path=optimizer_state_dir, + read=ReaderWriterOptimizationState.read, + write=ReaderWriterOptimizationState.write, + version_file=optimizer_state_dir / ".version", + ) + _seed_snapshot = VersionedResource.new( + resource=SeedSnapshot.new_capture(), + path=seed_dir, + read=ReaderWriterSeedSnapshot.read, + write=ReaderWriterSeedSnapshot.write, + version_file=seed_dir / ".version", + ) + _shared_errors = VersionedResource.new( + resource=ErrDump(), + path=error_dir, + read=ReaderWriterErrDump.read, + write=ReaderWriterErrDump.write, + version_file=error_dir / ".version", + ) - def optimizer_state(self) -> OptimizationState: - """Get the optimizer state.""" - return self._optimizer_state.synced() + return cls( + path=path, + _trials=TrialRepo(config_dir, version_file=config_dir / ".versions"), + # Locks, + _trial_lock=FileLocker( + lock_path=path / ".configs.lock", + poll=TRIAL_FILELOCK_POLL, + timeout=TRIAL_FILELOCK_TIMEOUT, + ), + _state_lock=FileLocker( + lock_path=path / ".state.lock", + poll=STATE_FILELOCK_POLL, + timeout=STATE_FILELOCK_TIMEOUT, + ), + _err_lock=FileLocker( + lock_path=error_dir / "errors.lock", + poll=TRIAL_FILELOCK_POLL, + timeout=TRIAL_FILELOCK_TIMEOUT, + ), + # State + _optimizer_info=_optimizer_info, + _optimizer_state=_optimizer_state, + _seed_snapshot=_seed_snapshot, + _shared_errors=_shared_errors, + ) diff --git a/neps/state/protocols.py b/neps/state/protocols.py deleted file mode 100644 index 7bbe7a9a..00000000 --- a/neps/state/protocols.py +++ /dev/null @@ -1,577 +0,0 @@ -"""This module defines the protocols used by 
-[`NePSState`][neps.state.neps_state.NePSState] and -[`Synced`][neps.state.synced.Synced] to ensure atomic operations to the state itself. -""" - -from __future__ import annotations - -import logging -from collections.abc import Callable, Iterable, Iterator -from contextlib import contextmanager -from copy import deepcopy -from dataclasses import dataclass -from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar -from typing_extensions import Self - -from neps.exceptions import ( - LockFailedError, - TrialAlreadyExistsError, - TrialNotFoundError, - VersionedResourceAlreadyExistsError, - VersionedResourceDoesNotExistError, - VersionedResourceRemovedError, - VersionMismatchError, -) - -if TYPE_CHECKING: - from neps.state import Trial - -logger = logging.getLogger(__name__) - -T = TypeVar("T") -K = TypeVar("K") - -# https://github.com/MaT1g3R/option/issues/40 -K2 = TypeVar("K2") -T2 = TypeVar("T2") - -Loc_contra = TypeVar("Loc_contra", contravariant=True) - - -class Versioner(Protocol): - """A versioner that can bump the version of a resource. - - It should have some [`current()`][neps.state.protocols.Versioner.current] method - to give the current version tag of a resource and a - [`bump()`][neps.state.protocols.Versioner.bump] method to provide a new version tag. - - These [`current()`][neps.state.protocols.Versioner.current] and - [`bump()`][neps.state.protocols.Versioner.bump] methods do not need to be atomic - but they should read/write to external state, i.e. file-system, database, etc. - """ - - def current(self) -> str | None: - """Return the current version as defined by the external state, i.e. - the version of the tag on disk. - - Returns: - The current version if there is one written. - """ - ... - - def bump(self) -> str: - """Create a new external version tag. - - Returns: - The new version tag. - """ - ... 
- - -class Locker(Protocol): - """A locker that can be used to communicate between workers.""" - - LockFailedError: ClassVar = LockFailedError - - @contextmanager - def lock(self) -> Iterator[None]: - """Initiate the lock as a context manager, releasing it when done.""" - ... - - def is_locked(self) -> bool: - """Check if lock is...well, locked. - - Should return True if the resource is locked, even if the lock is held by the - current worker/process. - """ - ... - - -class ReaderWriter(Protocol[T, Loc_contra]): - """A reader-writer that can read and write some resource T with location Loc. - - For example, a `ReaderWriter[Trial, Path]` indicates a class that can read and write - trials, given some `Path`. - """ - - CHEAP_LOCKLESS_READ: ClassVar[bool] - """Whether reading the contents of the resource is cheap, cheap enough to be - most likely safe without a lock if outdated information is acceptable. - - This is currently used to help debugging instances of a VersionMismatchError - to see what the current state is and what was attempted to be written. - """ - - def read(self, loc: Loc_contra, /) -> T: - """Read the resource at the given location.""" - ... - - def write(self, value: T, loc: Loc_contra, /) -> None: - """Write the resource at the given location.""" - ... - - -class TrialRepo(Protocol[K]): - """A repository of trials. - - The primary purpose of this protocol is to ensure consistent access to trial, - the ability to put in a new trial and know about the trials that are stored there. - """ - - TrialAlreadyExistsError: ClassVar = TrialAlreadyExistsError - TrialNotFoundError: ClassVar = TrialNotFoundError - - def all_trial_ids(self) -> list[str]: - """List all the trial ids in this trial Repo.""" - ... - - def get_by_id(self, trial_id: str) -> Synced[Trial, K]: - """Get a trial by its id.""" - ... - - def put_new(self, trial: Trial) -> Synced[Trial, K]: - """Put a new trial in the repo.""" - ... 
- - def all(self) -> dict[str, Synced[Trial, K]]: - """Get all trials in the repo.""" - ... - - def pending(self) -> Iterable[tuple[str, Trial]]: - """Get all pending trials in the repo. - - !!! note - This should return trials in the order in which they should be next evaluated, - usually the order in which they were put in the repo. - """ - ... - - -@dataclass -class VersionedResource(Generic[T, K]): - """A resource that will be read if it needs to update to the latest version. - - Relies on 3 main components: - * A [`Versioner`][neps.state.protocols.Versioner] to manage the versioning of the - resource. - * A [`ReaderWriter`][neps.state.protocols.ReaderWriter] to read and write the - resource. - * The location of the resource that can be used for the reader-writer. - """ - - VersionMismatchError: ClassVar = VersionMismatchError - VersionedResourceDoesNotExistsError: ClassVar = VersionedResourceDoesNotExistError - VersionedResourceAlreadyExistsError: ClassVar = VersionedResourceAlreadyExistsError - VersionedResourceRemovedError: ClassVar = VersionedResourceRemovedError - - _current: T - _location: K - _version: str - _versioner: Versioner - _reader_writer: ReaderWriter[T, K] - - @staticmethod - def new( - *, - data: T2, - location: K2, - versioner: Versioner, - reader_writer: ReaderWriter[T2, K2], - ) -> VersionedResource[T2, K2]: - """Create a new VersionedResource. - - This will create a new resource if it doesn't exist, otherwise, - if it already exists, it will raise an error. - - Use [`load()`][neps.state.protocols.VersionedResource.load] if you want to - load an existing resource. - - Args: - data: The data to be stored. - location: The location where the data will be stored. - versioner: The versioner to be used. - reader_writer: The reader-writer to be used. - - Returns: - A new VersionedResource - - Raises: - VersionedResourceAlreadyExistsError: If a versioned resource already exists - at the given location. 
- """ - current_version = versioner.current() - if current_version is not None: - raise VersionedResourceAlreadyExistsError( - f"A versioned resource already already exists at '{location}'" - f" with version '{current_version}'" - ) - - version = versioner.bump() - reader_writer.write(data, location) - return VersionedResource( - _current=data, - _location=location, - _version=version, - _versioner=versioner, - _reader_writer=reader_writer, - ) - - @classmethod - def load( - cls, - *, - location: K2, - versioner: Versioner, - reader_writer: ReaderWriter[T2, K2], - ) -> VersionedResource[T2, K2]: - """Load an existing VersionedResource. - - This will load an existing resource if it exists, otherwise, it will raise an - error. - - Use [`new()`][neps.state.protocols.VersionedResource.new] if you want to - create a new resource. - - Args: - location: The location of the resource. - versioner: The versioner to be used. - reader_writer: The reader-writer to be used. - - Returns: - A VersionedResource - - Raises: - VersionedResourceDoesNotExistsError: If no versioned resource exists at - the given location. - """ - version = versioner.current() - if version is None: - raise cls.VersionedResourceDoesNotExistsError( - f"No versioned resource exists at '{location}'." - ) - data = reader_writer.read(location) - return VersionedResource( - _current=data, - _location=location, - _version=version, - _versioner=versioner, - _reader_writer=reader_writer, - ) - - def sync_and_get(self) -> T: - """Get the data and version of the resource.""" - self.sync() - return self._current - - def sync(self) -> None: - """Sync the resource with the latest version.""" - current_version = self._versioner.current() - if current_version is None: - raise self.VersionedResourceRemovedError( - f"Versioned resource at '{self._location}' has been removed!" - f" Last known version was '{self._version}'." 
- ) - - if self._version != current_version: - self._current = self._reader_writer.read(self._location) - self._version = current_version - - def put(self, data: T) -> None: - """Put the data and version of the resource. - - Raises: - VersionMismatchError: If the version of the resource is not the same as the - current version. This implies that the resource has been updated by - another worker. - """ - current_version = self._versioner.current() - if self._version != current_version: - # We will attempt to do a lockless read on the contents of the items, as this - # would allow us to better debug in the error raised below. - if self._reader_writer.CHEAP_LOCKLESS_READ: - current_contents = self._reader_writer.read(self._location) - extra_msg = ( - f"\nThe attempted write was: {data}\n" - f"The current contents are: {current_contents}" - ) - else: - extra_msg = "" - - raise self.VersionMismatchError( - f"Version mismatch - ours: '{self._version}', remote: '{current_version}'" - f" Tried to put data at '{self._location}'. Doing so would overwrite" - " changes made by another worker. The solution is to pull the latest" - " version of the resource and try again." - " The most possible reasons for this error is that a lock was not" - " utilized when getting this resource before putting it back." - f"{extra_msg}" - ) - - self._reader_writer.write(data, self._location) - self._current = data - self._version = self._versioner.bump() - - def current(self) -> T: - """Get the current data of the resource.""" - return self._current - - def is_stale(self) -> bool: - """Check if the resource is stale.""" - return self._version != self._versioner.current() - - def location(self) -> K: - """Get the location of the resource.""" - return self._location - - -@dataclass -class Synced(Generic[T, K]): - """Manages a versioned resource but it's methods also implement locking procedures - for accessing it. 
- - Its types are parametrized by two type variables: - - * `T` is the type of the data stored in the resource. - * `K` is the type of the location of the resource, for example `Path` - - This wraps a [`VersionedResource`][neps.state.protocols.VersionedResource] and - additionally provides utility to perform atmoic operations on it using a - [`Locker`][neps.state.protocols.Locker]. - - This is used by [`NePSState`][neps.state.neps_state.NePSState] to manage the state - of trials and other shared resources. - - It consists of 2 main components: - - * A [`VersionedResource`][neps.state.protocols.VersionedResource] to manage the - versioning of the resource. - * A [`Locker`][neps.state.protocols.Locker] to manage the locking of the resource. - - The primary methods to interact with a resource that is behined a `Synced` are: - - * [`synced()`][neps.state.protocols.Synced.synced] to get the data of the resource - after syncing it to it's latest verison. - * [`acquire()`][neps.state.protocols.Synced.acquire] context manager to get latest - version of the data while also mainting a lock on it. This additionally provides - a `put()` operation to put the data back. This can primarily be used to get the - data, perform some mutation on it and then put it back, while not allowing other - workers access to the data. 
- """ - - LockFailedError: ClassVar = Locker.LockFailedError - VersionedResourceRemovedError: ClassVar = ( - VersionedResource.VersionedResourceRemovedError - ) - VersionMismatchError: ClassVar = VersionedResource.VersionMismatchError - VersionedResourceAlreadyExistsError: ClassVar = ( - VersionedResource.VersionedResourceAlreadyExistsError - ) - VersionedResourceDoesNotExistsError: ClassVar = ( - VersionedResource.VersionedResourceDoesNotExistsError - ) - - _resource: VersionedResource[T, K] - _locker: Locker - - @classmethod - def new( - cls, - *, - locker: Locker, - data: T2, - location: K2, - versioner: Versioner, - reader_writer: ReaderWriter[T2, K2], - ) -> Synced[T2, K2]: - """Create a new Synced resource. - - This will create a new resource if it doesn't exist, otherwise, - if it already exists, it will raise an error. - - Use [`load()`][neps.state.protocols.Synced.load] if you want to load an existing - resource. Use [`new_or_load()`][neps.state.protocols.Synced.new_or_load] if you - want to create a new resource if it doesn't exist, otherwise load an existing - resource. - - Args: - locker: The locker to be used. - data: The data to be stored. - location: The location where the data will be stored. - versioner: The versioner to be used. - reader_writer: The reader-writer to be used. - - Returns: - A new Synced resource. - - Raises: - VersionedResourceAlreadyExistsError: If a versioned resource already exists - at the given location. - """ - with locker.lock(): - vr = VersionedResource.new( - data=data, - location=location, - versioner=versioner, - reader_writer=reader_writer, - ) - return Synced(_resource=vr, _locker=locker) - - @classmethod - def load( - cls, - *, - locker: Locker, - location: K2, - versioner: Versioner, - reader_writer: ReaderWriter[T2, K2], - ) -> Synced[T2, K2]: - """Load an existing Synced resource. - - This will load an existing resource if it exists, otherwise, it will raise an - error. 
- - Use [`new()`][neps.state.protocols.Synced.new] if you want to create a new - resource. Use [`new_or_load()`][neps.state.protocols.Synced.new_or_load] if you - want to create a new resource if it doesn't exist, otherwise load an existing - resource. - - Args: - locker: The locker to be used. - location: The location of the resource. - versioner: The versioner to be used. - reader_writer: The reader-writer to be used. - - Returns: - A Synced resource. - - Raises: - VersionedResourceDoesNotExistsError: If no versioned resource exists at - the given location. - """ - with locker.lock(): - return Synced( - _resource=VersionedResource.load( - location=location, - versioner=versioner, - reader_writer=reader_writer, - ), - _locker=locker, - ) - - @classmethod - def new_or_load( - cls, - *, - locker: Locker, - data: T2, - location: K2, - versioner: Versioner, - reader_writer: ReaderWriter[T2, K2], - ) -> Synced[T2, K2]: - """Create a new Synced resource if it doesn't exist, otherwise load it. - - This will create a new resource if it doesn't exist, otherwise, it will load - an existing resource. - - Use [`new()`][neps.state.protocols.Synced.new] if you want to create a new - resource and fail otherwise. Use [`load()`][neps.state.protocols.Synced.load] - if you want to load an existing resource and fail if it doesn't exist. - - Args: - locker: The locker to be used. - data: The data to be stored. - - !!! warning - - This will be ignored if the data already exists. - - location: The location where the data will be stored. - versioner: The versioner to be used. - reader_writer: The reader-writer to be used. - - Returns: - A Synced resource. 
- """ - try: - return Synced.new( - locker=locker, - data=data, - location=location, - versioner=versioner, - reader_writer=reader_writer, - ) - except VersionedResourceAlreadyExistsError: - return Synced.load( - locker=locker, - location=location, - versioner=versioner, - reader_writer=reader_writer, - ) - - def synced(self) -> T: - """Get the data of the resource atomically.""" - with self._locker.lock(): - return self._resource.sync_and_get() - - def location(self) -> K: - """Get the location of the resource.""" - return self._resource.location() - - def put(self, data: T) -> None: - """Update the data atomically.""" - with self._locker.lock(): - self._resource.put(data) - - @contextmanager - def acquire(self) -> Iterator[tuple[T, Callable[[T], None]]]: - """Acquire the lock and get the data of the resource. - - This is a context manager that returns the data of the resource and a function - to put the data back. - - !!! note - This is the primary way to get the resource, mutate it and put it back. - Otherwise you likely want [`synced()`][neps.state.protocols.Synced.synced] - or [`put()`][neps.state.protocols.Synced.put]. - - Yields: - A tuple containing the data of the resource and a function to put the data - back. 
- """ - with self._locker.lock(): - self._resource.sync() - yield self._resource.current(), self._put_unsafe - - def deepcopy(self) -> Self: - """Create a deep copy of the shared resource.""" - return deepcopy(self) - - def _components(self) -> tuple[T, K, Versioner, ReaderWriter[T, K], Locker]: - """Get the components of the shared resource.""" - return ( - self._resource.current(), - self._resource.location(), - self._resource._versioner, - self._resource._reader_writer, - self._locker, - ) - - def _unsynced(self) -> T: - """Get the current data of the resource **without** locking and syncing it.""" - return self._resource.current() - - def _is_stale(self) -> bool: - """Check if the data held currently is not the latest version.""" - return self._resource.is_stale() - - def _is_locked(self) -> bool: - """Check if the resource is locked.""" - return self._locker.is_locked() - - def _put_unsafe(self, data: T) -> None: - """Put the data without checking for staleness or acquiring the lock. - - !!! warning - This should only really be called if you know what you're doing. 
- """ - self._resource.put(data) diff --git a/neps/status/status.py b/neps/status/status.py index bb68e50d..4b47eaeb 100644 --- a/neps/status/status.py +++ b/neps/status/status.py @@ -9,9 +9,9 @@ import pandas as pd -from neps.state.filebased import load_filebased_neps_state +from neps.state.filebased import FileLocker +from neps.state.neps_state import NePSState from neps.state.trial import Trial -from neps.utils._locker import Locker from neps.utils.types import ConfigID, _ConfigResultForStats if TYPE_CHECKING: @@ -37,9 +37,8 @@ def get_summary_dict( # NOTE: We don't lock the shared state since we are just reading and don't need to # make decisions based on the state - shared_state = load_filebased_neps_state(root_directory) - - trials = shared_state.get_all_trials() + shared_state = NePSState.create_or_load(root_directory, load_only=True) + trials = shared_state.lock_and_read_trials() evaluated: dict[ConfigID, _ConfigResultForStats] = {} @@ -160,7 +159,7 @@ def status( return summary["previous_results"], summary["pending_configs"] -def _initiate_summary_csv(root_directory: str | Path) -> tuple[Path, Path, Locker]: +def _initiate_summary_csv(root_directory: str | Path) -> tuple[Path, Path, FileLocker]: """Initializes a summary CSV and an associated locker for file access control. 
Args: @@ -181,7 +180,7 @@ def _initiate_summary_csv(root_directory: str | Path) -> tuple[Path, Path, Locke csv_config_data = summary_csv_directory / "config_data.csv" csv_run_data = summary_csv_directory / "run_status.csv" - csv_locker = Locker(summary_csv_directory / ".csv_lock") + csv_locker = FileLocker(summary_csv_directory / ".csv_lock", poll=2, timeout=600) return ( csv_config_data, @@ -282,7 +281,7 @@ def _get_dataframes_from_summary( def _save_data_to_csv( config_data_file_path: Path, run_data_file_path: Path, - locker: Locker, + locker: FileLocker, config_data_df: pd.DataFrame, run_data_df: pd.DataFrame, ) -> None: @@ -299,7 +298,7 @@ def _save_data_to_csv( config_data_df: The DataFrame containing configuration data. run_data_df: The DataFrame containing additional run data. """ - with locker(poll=2, timeout=600): + with locker.lock(): try: pending_configs = run_data_df.loc["num_pending_configs", "value"] pending_configs_with_worker = run_data_df.loc[ diff --git a/neps/utils/_locker.py b/neps/utils/_locker.py deleted file mode 100644 index f2d430f8..00000000 --- a/neps/utils/_locker.py +++ /dev/null @@ -1,62 +0,0 @@ -from __future__ import annotations - -from collections.abc import Iterator -from contextlib import contextmanager -from pathlib import Path -from typing import IO - -import portalocker as pl - -EXCLUSIVE_NONE_BLOCKING = pl.LOCK_EX | pl.LOCK_NB - - -class Locker: - FailedToAcquireLock = pl.exceptions.LockException - - def __init__(self, lock_path: Path): - self.lock_path = lock_path - self.lock_path.touch(exist_ok=True) - - @contextmanager - def try_lock(self) -> Iterator[bool]: - try: - with self.acquire(fail_when_locked=True): - yield True - except self.FailedToAcquireLock: - yield False - - def is_locked(self) -> bool: - with self.try_lock() as acquired_lock: - return not acquired_lock - - @contextmanager - def __call__( - self, - poll: float = 1, - *, - timeout: float | None = None, - fail_when_locked: bool = False, - ) -> Iterator[IO]: - 
with pl.Lock( - self.lock_path, - check_interval=poll, - timeout=timeout, - flags=EXCLUSIVE_NONE_BLOCKING, - fail_when_locked=fail_when_locked, - ) as fh: - yield fh # We almost never use it but nothing better to yield - - @contextmanager - def acquire( - self, - poll: float = 1.0, - *, - timeout: float | None = None, - fail_when_locked: bool = False, - ) -> Iterator[IO]: - with self( - poll, - timeout=timeout, - fail_when_locked=fail_when_locked, - ) as fh: - yield fh diff --git a/neps/utils/cli.py b/neps/utils/cli.py index cde70357..455fec43 100644 --- a/neps/utils/cli.py +++ b/neps/utils/cli.py @@ -40,10 +40,6 @@ ) from neps.optimizers.base_optimizer import BaseOptimizer from neps.utils.run_args import load_and_return_object -from neps.state.filebased import ( - create_or_load_filebased_neps_state, - load_filebased_neps_state, -) from neps.state.neps_state import NePSState from neps.state.trial import Trial from neps.exceptions import VersionedResourceDoesNotExistError, TrialNotFoundError @@ -140,8 +136,8 @@ def init_config(args: argparse.Namespace) -> None: else: directory = Path(directory) is_new = not directory.exists() - _ = create_or_load_filebased_neps_state( - directory=directory, + _ = NePSState.create_or_load( + path=directory, optimizer_info=OptimizerInfo(optimizer_info), optimizer_state=OptimizationState( budget=( @@ -335,7 +331,7 @@ def info_config(args: argparse.Namespace) -> None: if neps_state is None: return try: - trial = neps_state.get_trial_by_id(config_id) + trial = neps_state.unsafe_retry_get_trial_by_id(config_id) except TrialNotFoundError: print(f"No trial found with ID {config_id}.") return @@ -381,7 +377,7 @@ def load_neps_errors(args: argparse.Namespace) -> None: neps_state = load_neps_state(directory_path) if neps_state is None: return - errors = neps_state.get_errors() + errors = neps_state.lock_and_get_errors() if not errors.errs: print("No errors found.") @@ -441,7 +437,7 @@ def sample_config(args: argparse.Namespace) -> None: # 
Sample trials for _ in range(num_configs): try: - trial = neps_state.sample_trial(optimizer, worker_id=worker_id) + trial = neps_state.lock_and_sample_trial(optimizer, worker_id=worker_id) except Exception as e: print(f"Error during configuration sampling: {e}") continue # Skip to the next iteration @@ -491,10 +487,9 @@ def status(args: argparse.Namespace) -> None: summary = get_summary_dict(directory_path, add_details=True) # Calculate the number of trials in different states + trials = neps_state.lock_and_read_trials() evaluating_trials_count = sum( - 1 - for trial in neps_state.get_all_trials().values() - if trial.state.name == "EVALUATING" + 1 for trial in trials.values() if trial.state == Trial.State.EVALUATING ) pending_trials_count = summary["num_pending_configs"] succeeded_trials_count = summary["num_evaluated_configs"] - summary["num_error"] @@ -503,7 +498,7 @@ def status(args: argparse.Namespace) -> None: # Print summary print("NePS Status:") print("-----------------------------") - print(f"Optimizer: {neps_state.optimizer_info().info['searcher_alg']}") + print(f"Optimizer: {neps_state.lock_and_get_optimizer_info().info['searcher_alg']}") print(f"Succeeded Trials: {succeeded_trials_count}") print(f"Failed Trials (Errors): {failed_trials_count}") print(f"Active Trials: {evaluating_trials_count}") @@ -514,9 +509,8 @@ def status(args: argparse.Namespace) -> None: print("-----------------------------") # Retrieve and sort the trials by time_sampled - all_trials = neps_state.get_all_trials() sorted_trials = sorted( - all_trials.values(), key=lambda t: t.metadata.time_sampled, reverse=True + trials.values(), key=lambda t: t.metadata.time_sampled, reverse=True ) # Filter trials based on state @@ -589,7 +583,7 @@ def status(args: argparse.Namespace) -> None: print("\nNo successful trial found.") # Display optimizer information - optimizer_info = neps_state.optimizer_info().info + optimizer_info = neps_state.lock_and_get_optimizer_info().info searcher_name = 
optimizer_info.get("searcher_name", "N/A") searcher_alg = optimizer_info.get("searcher_alg", "N/A") searcher_args = optimizer_info.get("searcher_args", {}) @@ -631,7 +625,7 @@ def sort_trial_id(trial_id: str) -> List[int]: # Convert each part to an integer for proper numeric sorting return [int(part) for part in parts] - trials = neps_state.get_all_trials() + trials = neps_state.lock_and_read_trials() sorted_trials = sorted(trials.values(), key=lambda x: sort_trial_id(x.id)) # Compute incumbents @@ -662,10 +656,10 @@ def sort_trial_id(trial_id: str) -> List[int]: print(f"Plot saved to '{plot_path}'.") -def load_neps_state(directory_path: Path) -> Optional[NePSState[Path]]: +def load_neps_state(directory_path: Path) -> Optional[NePSState]: """Load the NePS state with error handling.""" try: - return load_filebased_neps_state(directory_path) + return NePSState.create_or_load(directory_path, load_only=True) except VersionedResourceDoesNotExistError: print(f"Error: No NePS state found in the directory '{directory_path}'.") print("Ensure that the NePS run has been initialized correctly.") @@ -679,7 +673,11 @@ def compute_incumbents(sorted_trials: List[Trial]) -> List[Trial]: best_loss = float("inf") incumbents = [] for trial in sorted_trials: - if trial.report and trial.report.loss < best_loss: + if ( + trial.report is not None + and trial.report.loss is not None + and trial.report.loss < best_loss + ): best_loss = trial.report.loss incumbents.append(trial) return incumbents[::-1] # Reverse for most recent first @@ -1031,7 +1029,7 @@ def handle_report_config(args: argparse.Namespace) -> None: # Load the existing trial by ID try: - trial = neps_state.get_trial_by_id(args.trial_id) + trial = neps_state.unsafe_retry_get_trial_by_id(args.trial_id) if not trial: print(f"No trial found with ID {args.trial_id}") return @@ -1054,7 +1052,7 @@ def handle_report_config(args: argparse.Namespace) -> None: # Update NePS state try: - neps_state.report_trial_evaluation( + 
neps_state._report_trial_evaluation( trial=trial, report=report, worker_id=args.worker_id ) except Exception as e: diff --git a/neps/utils/common.py b/neps/utils/common.py index 1e735063..3887565e 100644 --- a/neps/utils/common.py +++ b/neps/utils/common.py @@ -160,7 +160,7 @@ def get_initial_directory(pipeline_directory: Path | str | None = None) -> Path: if pipeline_directory is not None: # TODO: Hard coded assumption config_id = Path(pipeline_directory).name.split("_", maxsplit=1)[-1] - trial = neps_state.get_trial_by_id(config_id) + trial = neps_state.unsafe_retry_get_trial_by_id(config_id) else: trial = get_in_progress_trial() @@ -169,7 +169,7 @@ def get_initial_directory(pipeline_directory: Path | str | None = None) -> Path: # Recursively find the initial directory while (prev_trial_id := trial.metadata.previous_trial_id) is not None: - trial = neps_state.get_trial_by_id(prev_trial_id) + trial = neps_state.unsafe_retry_get_trial_by_id(prev_trial_id) initial_dir = trial.metadata.location diff --git a/tests/test_runtime/test_default_report_values.py b/tests/test_runtime/test_default_report_values.py index 265d4c08..d857c69a 100644 --- a/tests/test_runtime/test_default_report_values.py +++ b/tests/test_runtime/test_default_report_values.py @@ -4,7 +4,6 @@ from neps.optimizers.random_search.optimizer import RandomSearch from neps.runtime import DefaultWorker from neps.search_spaces.search_space import SearchSpace -from neps.state.filebased import create_or_load_filebased_neps_state from neps.state.neps_state import NePSState from neps.state.optimizer import OptimizationState, OptimizerInfo from neps.state.settings import DefaultReportValues, OnErrorPossibilities, WorkerSettings @@ -13,9 +12,9 @@ @fixture -def neps_state(tmp_path: Path) -> NePSState[Path]: - return create_or_load_filebased_neps_state( - directory=tmp_path / "neps_state", +def neps_state(tmp_path: Path) -> NePSState: + return NePSState.create_or_load( + path=tmp_path / "neps_state", 
optimizer_info=OptimizerInfo(info={"nothing": "here"}), optimizer_state=OptimizationState(budget=None, shared_state={}), ) @@ -54,15 +53,15 @@ def eval_function(*args, **kwargs) -> float: ) worker.run() - trials = neps_state.get_all_trials() + trials = neps_state.lock_and_read_trials() n_crashed = sum( trial.state == Trial.State.CRASHED is not None for trial in trials.values() ) assert len(trials) == 1 assert n_crashed == 1 - assert neps_state.get_next_pending_trial() is None - assert len(neps_state.get_errors()) == 1 + assert neps_state.lock_and_get_next_pending_trial() is None + assert len(neps_state.lock_and_get_errors()) == 1 trial = trials.popitem()[1] assert trial.state == Trial.State.CRASHED @@ -104,15 +103,15 @@ def eval_function(*args, **kwargs) -> float: ) worker.run() - trials = neps_state.get_all_trials() + trials = neps_state.lock_and_read_trials() n_sucess = sum( trial.state == Trial.State.SUCCESS is not None for trial in trials.values() ) assert len(trials) == 1 assert n_sucess == 1 - assert neps_state.get_next_pending_trial() is None - assert len(neps_state.get_errors()) == 0 + assert neps_state.lock_and_get_next_pending_trial() is None + assert len(neps_state.lock_and_get_errors()) == 0 trial = trials.popitem()[1] assert trial.state == Trial.State.SUCCESS @@ -152,15 +151,15 @@ def eval_function(*args, **kwargs) -> float: ) worker.run() - trials = neps_state.get_all_trials() + trials = neps_state.lock_and_read_trials() n_sucess = sum( trial.state == Trial.State.SUCCESS is not None for trial in trials.values() ) assert len(trials) == 1 assert n_sucess == 1 - assert neps_state.get_next_pending_trial() is None - assert len(neps_state.get_errors()) == 0 + assert neps_state.lock_and_get_next_pending_trial() is None + assert len(neps_state.lock_and_get_errors()) == 0 trial = trials.popitem()[1] assert trial.state == Trial.State.SUCCESS diff --git a/tests/test_runtime/test_error_handling_strategies.py b/tests/test_runtime/test_error_handling_strategies.py 
index 05cf762a..d341aa2b 100644 --- a/tests/test_runtime/test_error_handling_strategies.py +++ b/tests/test_runtime/test_error_handling_strategies.py @@ -8,8 +8,6 @@ from neps.optimizers.random_search.optimizer import RandomSearch from neps.runtime import DefaultWorker from neps.search_spaces.search_space import SearchSpace -from neps.state.err_dump import SerializedError -from neps.state.filebased import create_or_load_filebased_neps_state from neps.state.neps_state import NePSState from neps.state.optimizer import OptimizationState, OptimizerInfo from neps.state.settings import DefaultReportValues, OnErrorPossibilities, WorkerSettings @@ -18,9 +16,9 @@ @fixture -def neps_state(tmp_path: Path) -> NePSState[Path]: - return create_or_load_filebased_neps_state( - directory=tmp_path / "neps_state", +def neps_state(tmp_path: Path) -> NePSState: + return NePSState.create_or_load( + path=tmp_path / "neps_state", optimizer_info=OptimizerInfo(info={"nothing": "here"}), optimizer_state=OptimizationState(budget=None, shared_state={}), ) @@ -61,15 +59,15 @@ def eval_function(*args, **kwargs) -> float: with pytest.raises(WorkerRaiseError): worker.run() - trials = neps_state.get_all_trials() + trials = neps_state.lock_and_read_trials() n_crashed = sum( trial.state == Trial.State.CRASHED is not None for trial in trials.values() ) assert len(trials) == 1 assert n_crashed == 1 - assert neps_state.get_next_pending_trial() is None - assert len(neps_state.get_errors()) == 1 + assert neps_state.lock_and_get_next_pending_trial() is None + assert len(neps_state.lock_and_get_errors()) == 1 def test_worker_raises_when_error_in_other_worker(neps_state: NePSState) -> None: @@ -114,15 +112,15 @@ def evaler(*args, **kwargs) -> float: with pytest.raises(WorkerRaiseError): worker2.run() - trials = neps_state.get_all_trials() + trials = neps_state.lock_and_read_trials() n_crashed = sum( trial.state == Trial.State.CRASHED is not None for trial in trials.values() ) assert len(trials) == 1 assert 
n_crashed == 1 - assert neps_state.get_next_pending_trial() is None - assert len(neps_state.get_errors()) == 1 + assert neps_state.lock_and_get_next_pending_trial() is None + assert len(neps_state.lock_and_get_errors()) == 1 @pytest.mark.parametrize( @@ -184,7 +182,7 @@ def __call__(self, *args, **kwargs) -> float: worker2.run() assert worker2.worker_cumulative_eval_count == 1 - trials = neps_state.get_all_trials() + trials = neps_state.lock_and_read_trials() n_success = sum( trial.state == Trial.State.SUCCESS is not None for trial in trials.values() ) @@ -195,5 +193,5 @@ def __call__(self, *args, **kwargs) -> float: assert n_crashed == 1 assert len(trials) == 2 - assert neps_state.get_next_pending_trial() is None - assert len(neps_state.get_errors()) == 1 + assert neps_state.lock_and_get_next_pending_trial() is None + assert len(neps_state.lock_and_get_errors()) == 1 diff --git a/tests/test_runtime/test_stopping_criterion.py b/tests/test_runtime/test_stopping_criterion.py index c73051a9..3e6da7ce 100644 --- a/tests/test_runtime/test_stopping_criterion.py +++ b/tests/test_runtime/test_stopping_criterion.py @@ -5,7 +5,7 @@ from neps.optimizers.random_search.optimizer import RandomSearch from neps.runtime import DefaultWorker from neps.search_spaces.search_space import SearchSpace -from neps.state.filebased import create_or_load_filebased_neps_state +from neps.state.neps_state import NePSState from neps.state.neps_state import NePSState from neps.state.optimizer import OptimizationState, OptimizerInfo from neps.state.settings import DefaultReportValues, OnErrorPossibilities, WorkerSettings @@ -14,9 +14,9 @@ @fixture -def neps_state(tmp_path: Path) -> NePSState[Path]: - return create_or_load_filebased_neps_state( - directory=tmp_path / "neps_state", +def neps_state(tmp_path: Path) -> NePSState: + return NePSState.create_or_load( + path=tmp_path / "neps_state", optimizer_info=OptimizerInfo(info={"nothing": "here"}), optimizer_state=OptimizationState(budget=None, 
shared_state={}), ) @@ -52,10 +52,10 @@ def eval_function(*args, **kwargs) -> float: worker.run() assert worker.worker_cumulative_eval_count == 3 - assert neps_state.get_next_pending_trial() is None - assert len(neps_state.get_errors()) == 0 + assert neps_state.lock_and_get_next_pending_trial() is None + assert len(neps_state.lock_and_get_errors()) == 0 - trials = neps_state.get_all_trials() + trials = neps_state.lock_and_read_trials() for _, trial in trials.items(): assert trial.state == Trial.State.SUCCESS assert trial.report is not None @@ -71,8 +71,8 @@ def eval_function(*args, **kwargs) -> float: ) new_worker.run() assert new_worker.worker_cumulative_eval_count == 0 - assert neps_state.get_next_pending_trial() is None - assert len(neps_state.get_errors()) == 0 + assert neps_state.lock_and_get_next_pending_trial() is None + assert len(neps_state.lock_and_get_errors()) == 0 def test_worker_evaluations_total_stopping_criterion( @@ -105,10 +105,10 @@ def eval_function(*args, **kwargs) -> float: worker.run() assert worker.worker_cumulative_eval_count == 2 - assert neps_state.get_next_pending_trial() is None - assert len(neps_state.get_errors()) == 0 + assert neps_state.lock_and_get_next_pending_trial() is None + assert len(neps_state.lock_and_get_errors()) == 0 - trials = neps_state.get_all_trials() + trials = neps_state.lock_and_read_trials() assert len(trials) == 2 for _, trial in trials.items(): assert trial.state == Trial.State.SUCCESS @@ -126,10 +126,10 @@ def eval_function(*args, **kwargs) -> float: new_worker.run() assert worker.worker_cumulative_eval_count == 2 - assert neps_state.get_next_pending_trial() is None - assert len(neps_state.get_errors()) == 0 + assert neps_state.lock_and_get_next_pending_trial() is None + assert len(neps_state.lock_and_get_errors()) == 0 - trials = neps_state.get_all_trials() + trials = neps_state.lock_and_read_trials() assert len(trials) == 4 # Now we should have 4 of them for _, trial in trials.items(): assert trial.state == 
Trial.State.SUCCESS @@ -155,7 +155,7 @@ def test_include_in_progress_evaluations_towards_maximum_with_work_eval_count( ) # We put in one trial as being inprogress - pending_trial = neps_state.sample_trial(optimizer, worker_id="dummy") + pending_trial = neps_state.lock_and_sample_trial(optimizer, worker_id="dummy") pending_trial.set_evaluating(time_started=0.0, worker_id="dummy") neps_state.put_updated_trial(pending_trial) @@ -173,11 +173,11 @@ def eval_function(*args, **kwargs) -> float: assert worker.worker_cumulative_eval_count == 1 assert ( - neps_state.get_next_pending_trial() is None + neps_state.lock_and_get_next_pending_trial() is None ) # should have no pending trials to be picked up - assert len(neps_state.get_errors()) == 0 + assert len(neps_state.lock_and_get_errors()) == 0 - trials = neps_state.get_all_trials() + trials = neps_state.lock_and_read_trials() assert len(trials) == 2 the_pending_trial = trials[pending_trial.id] @@ -225,11 +225,11 @@ def eval_function(*args, **kwargs) -> dict: assert worker.worker_cumulative_eval_count == 2 assert worker.worker_cumulative_eval_cost == 2.0 assert ( - neps_state.get_next_pending_trial() is None + neps_state.lock_and_get_next_pending_trial() is None ) # should have no pending trials to be picked up - assert len(neps_state.get_errors()) == 0 + assert len(neps_state.lock_and_get_errors()) == 0 - trials = neps_state.get_all_trials() + trials = neps_state.lock_and_read_trials() assert len(trials) == 2 # New worker should now not run anything as the total cost has been reached. 
@@ -276,11 +276,11 @@ def eval_function(*args, **kwargs) -> dict: assert worker.worker_cumulative_eval_count == 2 assert worker.worker_cumulative_eval_cost == 2.0 assert ( - neps_state.get_next_pending_trial() is None + neps_state.lock_and_get_next_pending_trial() is None ) # should have no pending trials to be picked up - assert len(neps_state.get_errors()) == 0 + assert len(neps_state.lock_and_get_errors()) == 0 - trials = neps_state.get_all_trials() + trials = neps_state.lock_and_read_trials() assert len(trials) == 2 # New worker should also run 2 more trials @@ -295,11 +295,11 @@ def eval_function(*args, **kwargs) -> dict: assert new_worker.worker_cumulative_eval_count == 2 assert new_worker.worker_cumulative_eval_cost == 2.0 assert ( - neps_state.get_next_pending_trial() is None + neps_state.lock_and_get_next_pending_trial() is None ) # should have no pending trials to be picked up - assert len(neps_state.get_errors()) == 0 + assert len(neps_state.lock_and_get_errors()) == 0 - trials = neps_state.get_all_trials() + trials = neps_state.lock_and_read_trials() assert len(trials) == 4 # 2 more trials were ran @@ -336,10 +336,10 @@ def eval_function(*args, **kwargs) -> float: assert worker.worker_cumulative_eval_count > 0 assert worker.worker_cumulative_evaluation_time_seconds <= 2.0 assert ( - neps_state.get_next_pending_trial() is None + neps_state.lock_and_get_next_pending_trial() is None ) # should have no pending trials to be picked up - assert len(neps_state.get_errors()) == 0 - len_trials_on_first_worker = len(neps_state.get_all_trials()) + assert len(neps_state.lock_and_get_errors()) == 0 + len_trials_on_first_worker = len(neps_state.lock_and_read_trials()) # New worker should also run some trials more trials new_worker = DefaultWorker.new( @@ -354,10 +354,10 @@ def eval_function(*args, **kwargs) -> float: assert new_worker.worker_cumulative_eval_count > 0 assert new_worker.worker_cumulative_evaluation_time_seconds <= 2.0 assert ( - 
neps_state.get_next_pending_trial() is None + neps_state.lock_and_get_next_pending_trial() is None ) # should have no pending trials to be picked up - assert len(neps_state.get_errors()) == 0 - len_trials_on_second_worker = len(neps_state.get_all_trials()) + assert len(neps_state.lock_and_get_errors()) == 0 + len_trials_on_second_worker = len(neps_state.lock_and_read_trials()) assert len_trials_on_second_worker > len_trials_on_first_worker @@ -395,10 +395,10 @@ def eval_function(*args, **kwargs) -> float: assert worker.worker_cumulative_eval_count > 0 assert worker.worker_cumulative_evaluation_time_seconds <= 1.0 assert ( - neps_state.get_next_pending_trial() is None + neps_state.lock_and_get_next_pending_trial() is None ) # should have no pending trials to be picked up - assert len(neps_state.get_errors()) == 0 - len_trials_on_first_worker = len(neps_state.get_all_trials()) + assert len(neps_state.lock_and_get_errors()) == 0 + len_trials_on_first_worker = len(neps_state.lock_and_read_trials()) # New worker should also run some trials more trials new_worker = DefaultWorker.new( @@ -413,10 +413,10 @@ def eval_function(*args, **kwargs) -> float: assert new_worker.worker_cumulative_eval_count > 0 assert new_worker.worker_cumulative_evaluation_time_seconds <= 1.0 assert ( - neps_state.get_next_pending_trial() is None + neps_state.lock_and_get_next_pending_trial() is None ) # should have no pending trials to be picked up - assert len(neps_state.get_errors()) == 0 - len_trials_on_second_worker = len(neps_state.get_all_trials()) + assert len(neps_state.lock_and_get_errors()) == 0 + len_trials_on_second_worker = len(neps_state.lock_and_read_trials()) assert len_trials_on_second_worker > len_trials_on_first_worker @@ -454,10 +454,10 @@ def eval_function(*args, **kwargs) -> float: assert worker.worker_cumulative_eval_count > 0 assert worker.worker_cumulative_evaluation_time_seconds <= 1.0 assert ( - neps_state.get_next_pending_trial() is None + 
neps_state.lock_and_get_next_pending_trial() is None ) # should have no pending trials to be picked up - assert len(neps_state.get_errors()) == 0 - len_trials_on_first_worker = len(neps_state.get_all_trials()) + assert len(neps_state.lock_and_get_errors()) == 0 + len_trials_on_first_worker = len(neps_state.lock_and_read_trials()) # New worker should also run some trials more trials new_worker = DefaultWorker.new( @@ -472,8 +472,8 @@ def eval_function(*args, **kwargs) -> float: assert new_worker.worker_cumulative_eval_count == 0 assert new_worker.worker_cumulative_evaluation_time_seconds == 0 assert ( - neps_state.get_next_pending_trial() is None + neps_state.lock_and_get_next_pending_trial() is None ) # should have no pending trials to be picked up - assert len(neps_state.get_errors()) == 0 - len_trials_on_second_worker = len(neps_state.get_all_trials()) + assert len(neps_state.lock_and_get_errors()) == 0 + len_trials_on_second_worker = len(neps_state.lock_and_read_trials()) assert len_trials_on_second_worker == len_trials_on_first_worker diff --git a/tests/test_state/test_filebased_neps_state.py b/tests/test_state/test_filebased_neps_state.py index 02f5a52c..87085639 100644 --- a/tests/test_state/test_filebased_neps_state.py +++ b/tests/test_state/test_filebased_neps_state.py @@ -6,10 +6,7 @@ from typing import Any from neps.exceptions import NePSError, TrialNotFoundError from neps.state.err_dump import ErrDump -from neps.state.filebased import ( - create_or_load_filebased_neps_state, - load_filebased_neps_state, -) +from neps.state.neps_state import NePSState import pytest from pytest_cases import fixture, parametrize @@ -38,21 +35,21 @@ def test_create_with_new_filebased_neps_state( optimizer_state: OptimizationState, ) -> None: new_path = tmp_path / "neps_state" - neps_state = create_or_load_filebased_neps_state( - directory=new_path, + neps_state = NePSState.create_or_load( + path=new_path, optimizer_info=optimizer_info, optimizer_state=optimizer_state, ) - 
assert neps_state.optimizer_info() == optimizer_info - assert neps_state.optimizer_state() == optimizer_state + assert neps_state.lock_and_get_optimizer_info() == optimizer_info + assert neps_state.lock_and_get_optimizer_state() == optimizer_state assert neps_state.all_trial_ids() == [] - assert neps_state.get_all_trials() == {} - assert neps_state.get_errors() == ErrDump(errs=[]) - assert neps_state.get_next_pending_trial() is None - assert neps_state.get_next_pending_trial(n=10) == [] + assert neps_state.lock_and_read_trials() == {} + assert neps_state.lock_and_get_errors() == ErrDump(errs=[]) + assert neps_state.lock_and_get_next_pending_trial() is None + assert neps_state.lock_and_get_next_pending_trial(n=10) == [] with pytest.raises(TrialNotFoundError): - assert neps_state.get_trial_by_id("1") + assert neps_state.lock_and_get_trial_by_id("1") def test_create_or_load_with_load_filebased_neps_state( @@ -61,8 +58,8 @@ def test_create_or_load_with_load_filebased_neps_state( optimizer_state: OptimizationState, ) -> None: new_path = tmp_path / "neps_state" - neps_state = create_or_load_filebased_neps_state( - directory=new_path, + neps_state = NePSState.create_or_load( + path=new_path, optimizer_info=optimizer_info, optimizer_state=optimizer_state, ) @@ -74,8 +71,8 @@ def test_create_or_load_with_load_filebased_neps_state( budget=BudgetInfo(max_cost_budget=20, used_cost_budget=10), shared_state={"c": "d"}, ) - neps_state2 = create_or_load_filebased_neps_state( - directory=new_path, + neps_state2 = NePSState.create_or_load( + path=new_path, optimizer_info=optimizer_info, optimizer_state=different_state, ) @@ -88,13 +85,13 @@ def test_load_on_existing_neps_state( optimizer_state: OptimizationState, ) -> None: new_path = tmp_path / "neps_state" - neps_state = create_or_load_filebased_neps_state( - directory=new_path, + neps_state = NePSState.create_or_load( + path=new_path, optimizer_info=optimizer_info, optimizer_state=optimizer_state, ) - neps_state2 = 
load_filebased_neps_state(directory=new_path) + neps_state2 = NePSState.create_or_load(path=new_path, load_only=True) assert neps_state == neps_state2 @@ -104,15 +101,15 @@ def test_new_or_load_on_existing_neps_state_with_different_optimizer_info( optimizer_state: OptimizationState, ) -> None: new_path = tmp_path / "neps_state" - create_or_load_filebased_neps_state( - directory=new_path, + NePSState.create_or_load( + path=new_path, optimizer_info=optimizer_info, optimizer_state=optimizer_state, ) with pytest.raises(NePSError): - create_or_load_filebased_neps_state( - directory=new_path, + NePSState.create_or_load( + path=new_path, optimizer_info=OptimizerInfo({"e": "f"}), optimizer_state=optimizer_state, ) diff --git a/tests/test_state/test_neps_state.py b/tests/test_state/test_neps_state.py index c64cb64e..78b3213b 100644 --- a/tests/test_state/test_neps_state.py +++ b/tests/test_state/test_neps_state.py @@ -15,9 +15,7 @@ Categorical, ) from neps.search_spaces.search_space import SearchSpace -from neps.state.filebased import ( - create_or_load_filebased_neps_state, -) +from neps.state.neps_state import NePSState from pytest_cases import fixture, parametrize, parametrize_with_cases, case from neps.state.neps_state import NePSState @@ -156,8 +154,8 @@ def case_neps_state_filebased( shared_state: dict[str, Any], ) -> NePSState: new_path = tmp_path / "neps_state" - return create_or_load_filebased_neps_state( - directory=new_path, + return NePSState.create_or_load( + path=new_path, optimizer_info=optimizer_info, optimizer_state=OptimizationState(budget=budget, shared_state=shared_state), ) @@ -169,15 +167,15 @@ def test_sample_trial( optimizer_and_key: tuple[BaseOptimizer, str], ) -> None: optimizer, key = optimizer_and_key - if key in REQUIRES_COST and neps_state.optimizer_state().budget is None: + if key in REQUIRES_COST and neps_state.lock_and_get_optimizer_state().budget is None: pytest.xfail(f"{key} requires a cost budget") - assert neps_state.get_all_trials() == 
{} - assert neps_state.get_next_pending_trial() is None - assert neps_state.get_next_pending_trial(n=10) == [] + assert neps_state.lock_and_read_trials() == {} + assert neps_state.lock_and_get_next_pending_trial() is None + assert neps_state.lock_and_get_next_pending_trial(n=10) == [] assert neps_state.all_trial_ids() == [] - trial1 = neps_state.sample_trial(optimizer=optimizer, worker_id="1") + trial1 = neps_state.lock_and_sample_trial(optimizer=optimizer, worker_id="1") for k, v in trial1.config.items(): assert k in optimizer.pipeline_space.hyperparameters assert v is not None, f"'{k}' is None in {trial1.config}" @@ -186,19 +184,19 @@ def test_sample_trial( # precise, we need to introduce a sleep -_- time.sleep(0.1) - assert neps_state.get_all_trials() == {trial1.id: trial1} - assert neps_state.get_next_pending_trial() == trial1 - assert neps_state.get_next_pending_trial(n=10) == [trial1] + assert neps_state.lock_and_read_trials() == {trial1.id: trial1} + assert neps_state.lock_and_get_next_pending_trial() == trial1 + assert neps_state.lock_and_get_next_pending_trial(n=10) == [trial1] assert neps_state.all_trial_ids() == [trial1.id] - trial2 = neps_state.sample_trial(optimizer=optimizer, worker_id="1") + trial2 = neps_state.lock_and_sample_trial(optimizer=optimizer, worker_id="1") for k, v in trial1.config.items(): assert k in optimizer.pipeline_space.hyperparameters assert v is not None, f"'{k}' is None in {trial1.config}" assert trial1 != trial2 - assert neps_state.get_all_trials() == {trial1.id: trial1, trial2.id: trial2} - assert neps_state.get_next_pending_trial() == trial1 - assert neps_state.get_next_pending_trial(n=10) == [trial1, trial2] + assert neps_state.lock_and_read_trials() == {trial1.id: trial1, trial2.id: trial2} + assert neps_state.lock_and_get_next_pending_trial() == trial1 + assert neps_state.lock_and_get_next_pending_trial(n=10) == [trial1, trial2] assert sorted(neps_state.all_trial_ids()) == [trial1.id, trial2.id] diff --git 
a/tests/test_state/test_synced.py b/tests/test_state/test_synced.py deleted file mode 100644 index 6294db37..00000000 --- a/tests/test_state/test_synced.py +++ /dev/null @@ -1,429 +0,0 @@ -import copy -import random - -from pytest_cases import parametrize, parametrize_with_cases, case -import numpy as np -from neps.state.err_dump import ErrDump, SerializableTrialError -from neps.state.filebased import ( - ReaderWriterErrDump, - ReaderWriterOptimizationState, - ReaderWriterOptimizerInfo, - ReaderWriterSeedSnapshot, - ReaderWriterTrial, - FileVersioner, - FileLocker, -) -from neps.state.optimizer import BudgetInfo, OptimizationState, OptimizerInfo -import pytest -from typing import Any, Callable -from pathlib import Path -from neps.state import SeedSnapshot, Synced, Trial - - -@case -def case_trial_1(tmp_path: Path) -> tuple[Synced[Trial, Path], Callable[[Trial], None]]: - trial_id = "1" - trial = Trial.new( - trial_id=trial_id, - location="", - config={"a": "b"}, - time_sampled=0, - previous_trial=None, - previous_trial_location=None, - worker_id=0, - ) - - def _update(trial: Trial) -> None: - trial.set_submitted(time_submitted=1) - - x = Synced.new( - data=trial, - location=tmp_path / "1", - locker=FileLocker(lock_path=tmp_path / "1" / ".lock", poll=0.1, timeout=None), - versioner=FileVersioner(version_file=tmp_path / "1" / ".version"), - reader_writer=ReaderWriterTrial(), - ) - return x, _update - - -@case -def case_trial_2(tmp_path: Path) -> tuple[Synced[Trial, Path], Callable[[Trial], None]]: - trial_id = "1" - trial = Trial.new( - trial_id=trial_id, - location="", - config={"a": "b"}, - time_sampled=0, - previous_trial=None, - previous_trial_location=None, - worker_id=0, - ) - trial.set_submitted(time_submitted=1) - - def _update(trial: Trial) -> None: - trial.set_evaluating(time_started=2, worker_id="1") - - x = Synced.new( - data=trial, - location=tmp_path / "1", - locker=FileLocker(lock_path=tmp_path / "1" / ".lock", poll=0.1, timeout=None), - 
versioner=FileVersioner(version_file=tmp_path / "1" / ".version"), - reader_writer=ReaderWriterTrial(), - ) - return x, _update - - -@case -def case_trial_3(tmp_path: Path) -> tuple[Synced[Trial, Path], Callable[[Trial], None]]: - trial_id = "1" - trial = Trial.new( - trial_id=trial_id, - config={"a": "b"}, - location="", - time_sampled=0, - previous_trial=None, - previous_trial_location=None, - worker_id=0, - ) - trial.set_submitted(time_submitted=1) - trial.set_evaluating(time_started=2, worker_id="1") - - def _update(trial: Trial) -> None: - trial.set_complete( - time_end=3, - loss=1, - cost=1, - extra={"hi": [1, 2, 3]}, - learning_curve=[1], - report_as="success", - evaluation_duration=1, - err=None, - tb=None, - ) - - x = Synced.new( - data=trial, - location=tmp_path / "1", - locker=FileLocker(lock_path=tmp_path / "1" / ".lock", poll=0.1, timeout=None), - versioner=FileVersioner(version_file=tmp_path / "1" / ".version"), - reader_writer=ReaderWriterTrial(), - ) - return x, _update - - -@case -def case_trial_4(tmp_path: Path) -> tuple[Synced[Trial, Path], Callable[[Trial], None]]: - trial_id = "1" - trial = Trial.new( - trial_id=trial_id, - config={"a": "b"}, - location="", - time_sampled=0, - previous_trial=None, - previous_trial_location=None, - worker_id=0, - ) - trial.set_submitted(time_submitted=1) - trial.set_evaluating(time_started=2, worker_id="1") - - def _update(trial: Trial) -> None: - trial.set_complete( - time_end=3, - loss=np.nan, - cost=np.inf, - extra={"hi": [1, 2, 3]}, - report_as="failed", - learning_curve=None, - evaluation_duration=2, - err=None, - tb=None, - ) - - x = Synced.new( - data=trial, - location=tmp_path / "1", - locker=FileLocker(lock_path=tmp_path / "1" / ".lock", poll=0.1, timeout=None), - versioner=FileVersioner(version_file=tmp_path / "1" / ".version"), - reader_writer=ReaderWriterTrial(), - ) - return x, _update - - -@case -def case_trial_5(tmp_path: Path) -> tuple[Synced[Trial, Path], Callable[[Trial], None]]: - trial_id = 
"1" - trial = Trial.new( - trial_id=trial_id, - config={"a": "b"}, - location="", - time_sampled=0, - previous_trial=None, - previous_trial_location=None, - worker_id=0, - ) - trial.set_submitted(time_submitted=1) - trial.set_evaluating(time_started=2, worker_id=1) - - def _update(trial: Trial) -> None: - trial.set_complete( - time_end=3, - loss=np.nan, - cost=np.inf, - extra={"hi": [1, 2, 3]}, - learning_curve=None, - evaluation_duration=2, - report_as="failed", - err=ValueError("hi"), - tb="something something traceback", - ) - - x = Synced.new( - data=trial, - location=tmp_path / "1", - locker=FileLocker(lock_path=tmp_path / "1" / ".lock", poll=0.1, timeout=None), - versioner=FileVersioner(version_file=tmp_path / "1" / ".version"), - reader_writer=ReaderWriterTrial(), - ) - return x, _update - - -@case -def case_trial_6(tmp_path: Path) -> tuple[Synced[Trial, Path], Callable[[Trial], None]]: - trial_id = "1" - trial = Trial.new( - trial_id=trial_id, - config={"a": "b"}, - location="", - time_sampled=0, - previous_trial=None, - previous_trial_location=None, - worker_id=0, - ) - trial.set_submitted(time_submitted=1) - trial.set_evaluating(time_started=2, worker_id=1) - - def _update(trial: Trial) -> None: - trial.set_corrupted() - - x = Synced.new( - data=trial, - location=tmp_path / "1", - locker=FileLocker(lock_path=tmp_path / "1" / ".lock", poll=0.1, timeout=None), - versioner=FileVersioner(version_file=tmp_path / "1" / ".version"), - reader_writer=ReaderWriterTrial(), - ) - return x, _update - - -@case -def case_trial_7(tmp_path: Path) -> tuple[Synced[Trial, Path], Callable[[Trial], None]]: - trial_id = "1" - trial = Trial.new( - trial_id=trial_id, - config={"a": "b"}, - location="", - time_sampled=0, - previous_trial=None, - previous_trial_location=None, - worker_id=0, - ) - trial.set_submitted(time_submitted=1) - trial.set_evaluating(time_started=2, worker_id=1) - trial.set_complete( - time_end=3, - loss=np.nan, - cost=np.inf, - extra={"hi": [1, 2, 3]}, - 
learning_curve=[1, 2, 3], - report_as="failed", - evaluation_duration=2, - err=ValueError("hi"), - tb="something something traceback", - ) - - def _update(trial: Trial) -> None: - trial.reset() - - x = Synced.new( - data=trial, - location=tmp_path / "1", - locker=FileLocker(lock_path=tmp_path / "1" / ".lock", poll=0.1, timeout=None), - versioner=FileVersioner(version_file=tmp_path / "1" / ".version"), - reader_writer=ReaderWriterTrial(), - ) - return x, _update - - -@case -def case_seed_snapshot( - tmp_path: Path, -) -> tuple[Synced[SeedSnapshot, Path], Callable[[SeedSnapshot], None]]: - seed = SeedSnapshot.new_capture() - - def _update(seed: SeedSnapshot) -> None: - random.randint(0, 100) - seed.recapture() - - x = Synced.new( - data=seed, - location=tmp_path / "seeds", - locker=FileLocker(lock_path=tmp_path / "seeds" / ".lock", poll=0.1, timeout=None), - versioner=FileVersioner(version_file=tmp_path / "seeds" / ".version"), - reader_writer=ReaderWriterSeedSnapshot(), - ) - return x, _update - - -@case -@parametrize( - "err", - [ - None, - SerializableTrialError( - trial_id="1", - worker_id="2", - err_type="ValueError", - err="hi", - tb="traceback\nmore", - ), - ], -) -def case_err_dump( - tmp_path: Path, - err: None | SerializableTrialError, -) -> tuple[Synced[ErrDump, Path], Callable[[ErrDump], None]]: - err_dump = ErrDump() if err is None else ErrDump(errs=[err]) - - def _update(err_dump: ErrDump) -> None: - new_err = SerializableTrialError( - trial_id="2", - worker_id="2", - err_type="RuntimeError", - err="hi", - tb="traceback\nless", - ) - err_dump.append(new_err) - - x = Synced.new( - data=err_dump, - location=tmp_path / "err_dump", - locker=FileLocker( - lock_path=tmp_path / "err_dump" / ".lock", poll=0.1, timeout=None - ), - versioner=FileVersioner(version_file=tmp_path / "err_dump" / ".version"), - reader_writer=ReaderWriterErrDump("all"), - ) - return x, _update - - -@case -def case_optimizer_info( - tmp_path: Path, -) -> tuple[Synced[OptimizerInfo, 
Path], Callable[[OptimizerInfo], None]]: - optimizer_info = OptimizerInfo(info={"a": "b"}) - - def _update(optimizer_info: OptimizerInfo) -> None: - optimizer_info.info["b"] = "c" # type: ignore # NOTE: We shouldn't be mutating but anywho... - - x = Synced.new( - data=optimizer_info, - location=tmp_path / "optimizer_info", - locker=FileLocker( - lock_path=tmp_path / "optimizer_info" / ".lock", poll=0.1, timeout=None - ), - versioner=FileVersioner(version_file=tmp_path / "optimizer_info" / ".version"), - reader_writer=ReaderWriterOptimizerInfo(), - ) - return x, _update - - -@case -@pytest.mark.parametrize( - "budget", (None, BudgetInfo(max_cost_budget=10, used_cost_budget=0)) -) -@pytest.mark.parametrize("shared_state", ({}, {"a": "b"})) -def case_optimization_state( - tmp_path: Path, - budget: BudgetInfo | None, - shared_state: dict[str, Any], -) -> tuple[Synced[OptimizationState, Path], Callable[[OptimizationState], None]]: - optimization_state = OptimizationState(budget=budget, shared_state=shared_state) - - def _update(optimization_state: OptimizationState) -> None: - optimization_state.shared_state["a"] = "c" # type: ignore # NOTE: We shouldn't be mutating but anywho... 
- optimization_state.budget = BudgetInfo(max_cost_budget=10, used_cost_budget=5) - - x = Synced.new( - data=optimization_state, - location=tmp_path / "optimizer_info", - locker=FileLocker( - lock_path=tmp_path / "optimizer_info" / ".lock", poll=0.1, timeout=None - ), - versioner=FileVersioner(version_file=tmp_path / "optimizer_info" / ".version"), - reader_writer=ReaderWriterOptimizationState(), - ) - return x, _update - - -@parametrize_with_cases("shared, update", cases=".") -def test_initial_state(shared: Synced, update: Callable) -> None: - assert shared._is_locked() == False - assert shared._is_stale() == False - assert shared._unsynced() == shared.synced() - - -@parametrize_with_cases("shared, update", cases=".") -def test_put_updates_current_data_and_is_not_stale( - shared: Synced, update: Callable -) -> None: - current_data = shared._unsynced() - - new_data = copy.deepcopy(current_data) - update(new_data) - assert new_data != current_data - - shared.put(new_data) - assert shared._unsynced() == new_data - assert shared._is_stale() == False - assert shared._is_locked() == False - - -@parametrize_with_cases("shared1, update", cases=".") -def test_share_synced_update_and_put(shared1: Synced, update: Callable) -> None: - shared2 = shared1.deepcopy() - assert shared1 == shared2 - assert not shared1._is_locked() - assert not shared2._is_locked() - - with shared2.acquire() as (data2, put2): - assert shared1._is_locked() - assert shared2._is_locked() - update(data2) - put2(data2) - - assert not shared1._is_locked() - assert not shared2._is_locked() - - assert shared1 != shared2 - assert shared1._unsynced() != shared2._unsynced() - assert shared1._is_stale() - - shared1.synced() - assert not shared1._is_stale() - assert not shared2._is_stale() - assert shared1._unsynced() == shared2._unsynced() - - -@parametrize_with_cases("shared, update", cases=".") -def test_shared_new_fails_if_done_on_existing_resource( - shared: Synced, update: Callable -) -> None: - data, 
location, versioner, rw, lock = shared._components() - with pytest.raises(Synced.VersionedResourceAlreadyExistsError): - Synced.new( - data=data, - location=location, - versioner=versioner, - reader_writer=rw, - locker=lock, - ) From 3a286b3ddb0626d76e7bdb7964eea0cc55a7e378 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 2 Dec 2024 18:35:07 +0100 Subject: [PATCH 28/56] fix: Use context manager for sample_trial --- neps/runtime.py | 46 +++++++++++++++++++++++----------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/neps/runtime.py b/neps/runtime.py index 308a2560..5cede349 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -417,30 +417,30 @@ def run(self) -> None: # noqa: C901, PLR0915, PLR0912 logger.info(should_stop) break - pending_trials = [ - trial - for trial in trials.values() - if trial.state == Trial.State.PENDING - ] - if len(pending_trials) > 0: - earliest_pending = sorted( - pending_trials, - key=lambda t: t.metadata.time_sampled, - )[0] - earliest_pending.set_evaluating( - time_started=time.time(), - worker_id=self.worker_id, - ) - self.state._trials.update_trial(earliest_pending) - trial_to_eval = earliest_pending - else: - sampled_trial = self.state._sample_trial( - optimizer=self.optimizer, - worker_id=self.worker_id, - ) - trial_to_eval = sampled_trial + pending_trials = [ + trial + for trial in trials.values() + if trial.state == Trial.State.PENDING + ] + if len(pending_trials) > 0: + earliest_pending = sorted( + pending_trials, + key=lambda t: t.metadata.time_sampled, + )[0] + earliest_pending.set_evaluating( + time_started=time.time(), + worker_id=self.worker_id, + ) + self.state._trials.update_trial(earliest_pending) + trial_to_eval = earliest_pending + else: + sampled_trial = self.state._sample_trial( + optimizer=self.optimizer, + worker_id=self.worker_id, + ) + trial_to_eval = sampled_trial - _repeated_fail_get_next_trial_count = 0 + _repeated_fail_get_next_trial_count = 0 except Exception as e: 
_repeated_fail_get_next_trial_count += 1 logger.debug( From 4b873f09ea18ef2e32e61131961bcb8619c3d024 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 2 Dec 2024 19:34:50 +0100 Subject: [PATCH 29/56] fix: Passing loaded trials to sample --- neps/runtime.py | 11 +++++++---- neps/state/neps_state.py | 14 +------------- 2 files changed, 8 insertions(+), 17 deletions(-) diff --git a/neps/runtime.py b/neps/runtime.py index 5cede349..f70f6904 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -403,8 +403,9 @@ def run(self) -> None: # noqa: C901, PLR0915, PLR0912 # needs to be in locked in-step with sampling try: # If there are no global stopping criterion, we can no just return early. - with self.state.lock_for_sampling(): - trials = self.state._trials.latest() + with self.state._state_lock.lock(): + with self.state._trial_lock.lock(): + trials = self.state._trials.latest() requires_checking_global_stopping_criterion = ( self.settings.max_evaluations_total is not None @@ -431,12 +432,14 @@ def run(self) -> None: # noqa: C901, PLR0915, PLR0912 time_started=time.time(), worker_id=self.worker_id, ) - self.state._trials.update_trial(earliest_pending) + with self.state._trial_lock.lock(): + self.state._trials.update_trial(earliest_pending) trial_to_eval = earliest_pending else: sampled_trial = self.state._sample_trial( optimizer=self.optimizer, worker_id=self.worker_id, + _trials=trials, ) trial_to_eval = sampled_trial @@ -530,7 +533,7 @@ def run(self) -> None: # noqa: C901, PLR0915, PLR0912 # We do not retry this, as if some other worker has # managed to manipulate this trial in the meantime, # then something has gone wrong - with self.state.lock_trials(): + with self.state._trial_lock.lock(): self.state._report_trial_evaluation( trial=evaluated_trial, report=report, diff --git a/neps/state/neps_state.py b/neps/state/neps_state.py index 39dc5265..26e3acfb 100644 --- a/neps/state/neps_state.py +++ b/neps/state/neps_state.py @@ -225,18 +225,6 @@ class NePSState: 
_err_lock: FileLocker = field(repr=False) _shared_errors: VersionedResource[ErrDump] = field(repr=False) - @contextmanager - def lock_for_sampling(self) -> Iterator[None]: - """Acquire the state lock and trials lock.""" - with self._state_lock.lock(), self._trial_lock.lock(): - yield - - @contextmanager - def lock_trials(self) -> Iterator[None]: - """Acquire the state lock.""" - with self._trial_lock.lock(): - yield - def lock_and_read_trials(self) -> dict[str, Trial]: """Acquire the state lock and read the trials.""" with self._trial_lock.lock(): @@ -244,7 +232,7 @@ def lock_and_read_trials(self) -> dict[str, Trial]: def lock_and_sample_trial(self, optimizer: BaseOptimizer, *, worker_id: str) -> Trial: """Acquire the state lock and sample a trial.""" - with self.lock_for_sampling(): + with self._state_lock.lock(), self._trial_lock.lock(): return self._sample_trial(optimizer, worker_id=worker_id) def lock_and_report_trial_evaluation( From a44f5511bf0f5f88def32bdceaa46409cd650762 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 2 Dec 2024 19:35:03 +0100 Subject: [PATCH 30/56] fix: Passing loaded trials to sample (2) --- neps/state/neps_state.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/neps/state/neps_state.py b/neps/state/neps_state.py index 26e3acfb..f1a9a03d 100644 --- a/neps/state/neps_state.py +++ b/neps/state/neps_state.py @@ -13,8 +13,7 @@ import logging import pickle import time -from collections.abc import Callable, Iterator -from contextlib import contextmanager +from collections.abc import Callable from dataclasses import dataclass, field from pathlib import Path from typing import ( From bb18db7cd925a2e38cb863a441edc4cd9ae4355b Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 2 Dec 2024 21:06:37 +0100 Subject: [PATCH 31/56] optim: reading/loading of items --- neps/runtime.py | 167 +++++++++++++++++++-------------------- neps/state/filebased.py | 98 +++++++++++++++++------ neps/state/neps_state.py | 95 
+++++++++++----------- neps/status/status.py | 7 +- neps/utils/files.py | 12 ++- 5 files changed, 221 insertions(+), 158 deletions(-) diff --git a/neps/runtime.py b/neps/runtime.py index f70f6904..9051ece9 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -23,12 +23,11 @@ LINUX_FILELOCK_FUNCTION, MAX_RETRIES_CREATE_LOAD_STATE, MAX_RETRIES_GET_NEXT_TRIAL, - MAX_RETRIES_SET_EVALUATING, MAX_RETRIES_WORKER_CHECK_SHOULD_STOP, ) from neps.exceptions import ( NePSError, - VersionMismatchError, + TrialAlreadyExistsError, WorkerFailedToGetPendingTrialsError, WorkerRaiseError, ) @@ -340,9 +339,84 @@ def _check_global_stopping_criterion( return False + @property + def _requires_global_stopping_criterion(self) -> bool: + return ( + self.settings.max_evaluations_total is not None + or self.settings.max_cost_total is not None + or self.settings.max_evaluation_time_total_seconds is not None + ) + + def _get_next_trial(self) -> Trial | Literal["break"]: + # If there are no global stopping criterion, we can no just return early. + with self.state._state_lock.lock(): + # With the trial lock, we'll load everything in, if we have a pending + # config, use that and return. 
+ with self.state._trial_lock.lock(): + trials = self.state._trials.latest() + + if self._requires_global_stopping_criterion: + should_stop = self._check_global_stopping_criterion(trials) + if should_stop is not False: + logger.info(should_stop) + return "break" + + pending_trials = [ + trial + for trial in trials.values() + if trial.state == Trial.State.PENDING + ] + + if len(pending_trials) > 0: + earliest_pending = sorted( + pending_trials, + key=lambda t: t.metadata.time_sampled, + )[0] + earliest_pending.set_evaluating( + time_started=time.time(), + worker_id=self.worker_id, + ) + self.state._trials.update_trial( + earliest_pending, + hints=["metadata", "state"], + ) + return earliest_pending + + # Otherwise, we release the trial lock while sampling + sampled_trial = self.state._sample_trial( + optimizer=self.optimizer, + worker_id=self.worker_id, + trials=trials, + ) + + with self.state._trial_lock.lock(): + try: + self.state._trials.new_trial(sampled_trial) + return sampled_trial + except TrialAlreadyExistsError as e: + if sampled_trial.id in trials: + logger.warning( + "The new sampled trial was given an id of '%s', yet this already" + " exists in the loaded in trials given to the optimizer. This" + " indicates a bug with the optimizers allocation of ids.", + sampled_trial.id, + ) + else: + logger.warning( + "The new sampled trial was given an id of '%s', which is not one" + " that was loaded in by the optimizer. This indicates that" + " configuration '%s' was put on disk during the time that this" + " worker had the optimizer state lock OR that after obtaining the" + " optimizer state lock, somehow this configuration failed to be" + " loaded in and passed to the optimizer.", + sampled_trial.id, + sampled_trial.id, + ) + raise e + # Forgive me lord, for I have sinned, this function is atrocious but complicated # due to locking. - def run(self) -> None: # noqa: C901, PLR0915, PLR0912 + def run(self) -> None: # noqa: C901, PLR0915 """Run the worker. 
Will keep running until one of the criterion defined by the `WorkerSettings` @@ -356,7 +430,6 @@ def run(self) -> None: # noqa: C901, PLR0915, PLR0912 _error_from_evaluation: Exception | None = None _repeated_fail_get_next_trial_count = 0 - n_failed_set_trial_state = 0 n_repeated_failed_check_should_stop = 0 while True: try: @@ -400,50 +473,13 @@ def run(self) -> None: # noqa: C901, PLR0915, PLR0912 # From here, we now begin sampling or getting the next pending trial. # As the global stopping criterion requires us to check all trials, and - # needs to be in locked in-step with sampling + # needs to be in locked in-step with sampling and is done inside + # _get_next_trial try: - # If there are no global stopping criterion, we can no just return early. - with self.state._state_lock.lock(): - with self.state._trial_lock.lock(): - trials = self.state._trials.latest() - - requires_checking_global_stopping_criterion = ( - self.settings.max_evaluations_total is not None - or self.settings.max_cost_total is not None - or self.settings.max_evaluation_time_total_seconds is not None - ) - if requires_checking_global_stopping_criterion: - should_stop = self._check_global_stopping_criterion(trials) - if should_stop is not False: - logger.info(should_stop) - break - - pending_trials = [ - trial - for trial in trials.values() - if trial.state == Trial.State.PENDING - ] - if len(pending_trials) > 0: - earliest_pending = sorted( - pending_trials, - key=lambda t: t.metadata.time_sampled, - )[0] - earliest_pending.set_evaluating( - time_started=time.time(), - worker_id=self.worker_id, - ) - with self.state._trial_lock.lock(): - self.state._trials.update_trial(earliest_pending) - trial_to_eval = earliest_pending - else: - sampled_trial = self.state._sample_trial( - optimizer=self.optimizer, - worker_id=self.worker_id, - _trials=trials, - ) - trial_to_eval = sampled_trial - - _repeated_fail_get_next_trial_count = 0 + trial_to_eval = self._get_next_trial() + if trial_to_eval == "break": 
+ break + _repeated_fail_get_next_trial_count = 0 except Exception as e: _repeated_fail_get_next_trial_count += 1 logger.debug( @@ -452,7 +488,6 @@ def run(self) -> None: # noqa: C901, PLR0915, PLR0912 exc_info=True, ) time.sleep(1) # Help stagger retries - # NOTE: This is to prevent any infinite loops if we can't get a trial if _repeated_fail_get_next_trial_count >= MAX_RETRIES_GET_NEXT_TRIAL: raise WorkerFailedToGetPendingTrialsError( @@ -463,42 +498,6 @@ def run(self) -> None: # noqa: C901, PLR0915, PLR0912 continue - # If we can't set this working to evaluating, then just retry the loop - try: - n_failed_set_trial_state = 0 - except VersionMismatchError: - n_failed_set_trial_state += 1 - logger.debug( - "Another worker has managed to change trial '%s'" - " while this worker '%s' was trying to set it to" - " evaluating. This is fine and likely means the other worker is" - " evaluating it, this worker will attempt to sample new trial.", - trial_to_eval.id, - self.worker_id, - exc_info=True, - ) - time.sleep(1) # Help stagger retries - except Exception: - n_failed_set_trial_state += 1 - logger.error( - "Unexpected error from worker '%s' trying to set trial" - " '%' to evaluating.", - self.worker_id, - trial_to_eval.id, - exc_info=True, - ) - time.sleep(1) # Help stagger retries - - # NOTE: This is to prevent infinite looping if it somehow keeps getting - # the same trial and can't set it to evaluating. - if n_failed_set_trial_state != 0: - if n_failed_set_trial_state >= MAX_RETRIES_SET_EVALUATING: - raise WorkerFailedToGetPendingTrialsError( - f"Worker {self.worker_id} failed to set trial to evaluating" - f" {MAX_RETRIES_SET_EVALUATING} times in a row. Bailing!" 
- ) - continue - # We (this worker) has managed to set it to evaluating, now we can evaluate it with _set_global_trial(trial_to_eval): evaluated_trial, report = evaluate_trial( diff --git a/neps/state/filebased.py b/neps/state/filebased.py index 3658d511..f6324860 100644 --- a/neps/state/filebased.py +++ b/neps/state/filebased.py @@ -7,7 +7,7 @@ from contextlib import contextmanager from dataclasses import asdict, dataclass from pathlib import Path -from typing import ClassVar, Final, TypeVar +from typing import ClassVar, Final, Literal, TypeAlias, TypeVar import numpy as np import portalocker as pl @@ -25,15 +25,23 @@ K = TypeVar("K") T = TypeVar("T") +TrialWriteHint: TypeAlias = Literal["metadata", "report", "state", "config"] + @dataclass class ReaderWriterTrial: """ReaderWriter for Trial objects.""" + # Report and config are kept as yaml since they are most likely to be + # read CONFIG_FILENAME = "config.yaml" - METADATA_FILENAME = "metadata.yaml" - STATE_FILENAME = "state.txt" REPORT_FILENAME = "report.yaml" + + # Metadata is put as json as it's more likely to be machine read and + # is much faster. 
+ METADATA_FILENAME = "metadata.json" + + STATE_FILENAME = "state.txt" PREVIOUS_TRIAL_ID_FILENAME = "previous_trial_id.txt" @classmethod @@ -43,9 +51,12 @@ def read(cls, directory: Path) -> Trial: state_path = directory / cls.STATE_FILENAME report_path = directory / cls.REPORT_FILENAME + with metadata_path.open("r") as f: + metadata = json.load(f) + return Trial( config=deserialize(config_path), - metadata=Trial.MetaData(**deserialize(metadata_path)), + metadata=Trial.MetaData(**metadata), state=Trial.State(state_path.read_text(encoding="utf-8").strip()), report=( Trial.Report(**deserialize(report_path)) if report_path.exists() else None @@ -53,22 +64,58 @@ def read(cls, directory: Path) -> Trial: ) @classmethod - def write(cls, trial: Trial, directory: Path) -> None: + def write( + cls, + trial: Trial, + directory: Path, + *, + hints: list[TrialWriteHint] | TrialWriteHint | None = None, + ) -> None: config_path = directory / cls.CONFIG_FILENAME metadata_path = directory / cls.METADATA_FILENAME state_path = directory / cls.STATE_FILENAME - serialize(trial.config, config_path) - serialize(asdict(trial.metadata), metadata_path) - state_path.write_text(trial.state.value, encoding="utf-8") - - if trial.metadata.previous_trial_id is not None: - previous_trial_path = directory / cls.PREVIOUS_TRIAL_ID_FILENAME - previous_trial_path.write_text(trial.metadata.previous_trial_id) - - if trial.report is not None: - report_path = directory / cls.REPORT_FILENAME - serialize(asdict(trial.report), report_path) + if isinstance(hints, str): + match hints: + case "config": + serialize(trial.config, config_path) + case "metadata": + with metadata_path.open("w") as f: + json.dump(asdict(trial.metadata), f) + + if trial.metadata.previous_trial_id is not None: + previous_trial_path = directory / cls.PREVIOUS_TRIAL_ID_FILENAME + previous_trial_path.write_text(trial.metadata.previous_trial_id) + case "report": + if trial.report is None: + raise ValueError( + "Cannot write report 'hint' 
when report is None." + ) + + report_path = directory / cls.REPORT_FILENAME + serialize(asdict(trial.report), report_path) + case "state": + state_path.write_text(trial.state.value, encoding="utf-8") + case _: + raise ValueError(f"Invalid hint: {hints}") + elif hints is None: + # We don't know, write everything + serialize(trial.config, config_path) + with metadata_path.open("w") as f: + json.dump(asdict(trial.metadata), f) + + if trial.metadata.previous_trial_id is not None: + previous_trial_path = directory / cls.PREVIOUS_TRIAL_ID_FILENAME + previous_trial_path.write_text(trial.metadata.previous_trial_id) + + state_path.write_text(trial.state.value, encoding="utf-8") + + if trial.report is not None: + report_path = directory / cls.REPORT_FILENAME + serialize(asdict(trial.report), report_path) + else: + for hint in hints: + cls.write(trial, directory, hints=hint) TrialReaderWriter: Final = ReaderWriterTrial() @@ -155,7 +202,9 @@ def write(cls, snapshot: SeedSnapshot, directory: Path) -> None: "py_rng_version": py_rng_version, "py_guass_next": py_guass_next, } - serialize(seed_info, seedinfo_path) + with seedinfo_path.open("w") as f: + json.dump(seed_info, f) + np_rng_state = snapshot.np_rng[1] np_rng_state.tofile(np_rng_path) @@ -195,23 +244,24 @@ def write(cls, optimizer_info: OptimizerInfo, directory: Path) -> None: class ReaderWriterOptimizationState: """ReaderWriter for OptimizationState objects.""" - STATE_FILE_NAME: ClassVar = "state.yaml" + STATE_FILE_NAME: ClassVar = "state.json" @classmethod def read(cls, directory: Path) -> OptimizationState: state_path = directory / cls.STATE_FILE_NAME - state = deserialize(state_path) + with state_path.open("r") as f: + state = json.load(f) + + shared_state = state.get("shared_state") or {} budget_info = state.get("budget") budget = BudgetInfo(**budget_info) if budget_info is not None else None - return OptimizationState( - shared_state=state.get("shared_state") or {}, - budget=budget, - ) + return 
OptimizationState(shared_state=shared_state, budget=budget) @classmethod def write(cls, info: OptimizationState, directory: Path) -> None: info_path = directory / cls.STATE_FILE_NAME - serialize(asdict(info), info_path) + with info_path.open("w") as f: + json.dump(asdict(info), f) @dataclass diff --git a/neps/state/neps_state.py b/neps/state/neps_state.py index f1a9a03d..67f5ee20 100644 --- a/neps/state/neps_state.py +++ b/neps/state/neps_state.py @@ -41,6 +41,7 @@ ReaderWriterOptimizerInfo, ReaderWriterSeedSnapshot, TrialReaderWriter, + TrialWriteHint, ) from neps.state.optimizer import OptimizationState, OptimizerInfo from neps.state.seed_snapshot import SeedSnapshot @@ -76,7 +77,8 @@ def make_sha() -> Version: class TrialRepo: directory: Path version_file: Path - cache: dict[str, tuple[Trial, Version]] = field(default_factory=dict) + trial_cache: dict[str, Trial] = field(default_factory=dict) + versions: dict[str, Version] = field(default_factory=dict) def __post_init__(self) -> None: self.directory.mkdir(parents=True, exist_ok=True) @@ -95,35 +97,40 @@ def latest(self) -> dict[str, Trial]: with self.version_file.open("rb") as f: versions_on_disk = pickle.load(f) # noqa: S301 - stale = { + stale: dict[str, Version] = { k: v for k, v in versions_on_disk.items() - if self.cache.get(k, (None, "__not_found__")) != v + if self.versions.get(k, "__not_found__") != v } - for trial_id, disk_version in stale.items(): + for trial_id, loaded_version in stale.items(): loaded_trial = self.load_trial_from_disk(trial_id) - self.cache[trial_id] = (loaded_trial, disk_version) + self.trial_cache[trial_id] = loaded_trial + self.versions[trial_id] = loaded_version - return {k: v[0] for k, v in self.cache.items()} + return self.trial_cache def new_trial(self, trial: Trial) -> None: config_path = self.directory / f"config_{trial.id}" if config_path.exists(): raise TrialAlreadyExistsError(trial.id, config_path) config_path.mkdir(parents=True, exist_ok=True) - self.update_trial(trial) + 
self.update_trial(trial, hints=None) - def update_trial(self, trial: Trial) -> None: + def update_trial( + self, trial: Trial, *, hints: list[TrialWriteHint] | TrialWriteHint | None = None + ) -> None: + TrialReaderWriter.write(trial, self.directory / f"config_{trial.id}", hints=hints) + self.trial_cache[trial.id] = trial new_version = make_sha() - TrialReaderWriter.write(trial, self.directory / f"config_{trial.id}") - self.cache[trial.id] = (trial, new_version) + self.versions[trial.id] = new_version + self._write_version_file() - def write_version_file(self) -> None: + def _write_version_file(self) -> None: with self.version_file.open("wb") as f: - pickle.dump({k: v[1] for k, v in self.cache.items()}, f) + pickle.dump(self.versions, f) def trials_in_memory(self) -> dict[str, Trial]: - return {k: v[0] for k, v in self.cache.items()} + return self.trial_cache def load_trial_from_disk(self, trial_id: str) -> Trial: config_path = self.directory / f"config_{trial_id}" @@ -231,8 +238,16 @@ def lock_and_read_trials(self) -> dict[str, Trial]: def lock_and_sample_trial(self, optimizer: BaseOptimizer, *, worker_id: str) -> Trial: """Acquire the state lock and sample a trial.""" - with self._state_lock.lock(), self._trial_lock.lock(): - return self._sample_trial(optimizer, worker_id=worker_id) + with self._state_lock.lock(): + with self._trial_lock.lock(): + trials = self._trials.latest() + + trial = self._sample_trial(optimizer, trials=trials, worker_id=worker_id) + + with self._trial_lock.lock(): + self._trials.new_trial(trial) + + return trial def lock_and_report_trial_evaluation( self, @@ -250,8 +265,8 @@ def _sample_trial( optimizer: BaseOptimizer, *, worker_id: str, + trials: dict[str, Trial], _sample_hooks: list[Callable] | None = None, - _trials: dict[str, Trial] | None = None, ) -> Trial: """Sample a new trial from the optimizer. @@ -262,12 +277,12 @@ def _sample_trial( Args: optimizer: The optimizer to sample the trial from. 
worker_id: The worker that is sampling the trial. + trials: The current trials. _sample_hooks: A list of hooks to apply to the optimizer before sampling. Returns: The new trial. """ - trials = self._trials.latest() if _trials is None else _trials seed_state = self._seed_snapshot.latest() opt_state = self._optimizer_state.latest() @@ -325,30 +340,6 @@ def _sample_trial( time_sampled=time.time(), worker_id=worker_id, ) - try: - self._trials.new_trial(trial) - self._trials.write_version_file() - except TrialAlreadyExistsError as e: - if sampled_config.id in trials: - logger.warning( - "The new sampled trial was given an id of '%s', yet this already" - " exists in the loaded in trials given to the optimizer. This" - " indicates a bug with the optimizers allocation of ids.", - sampled_config.id, - ) - else: - logger.warning( - "The new sampled trial was given an id of '%s', which is not one" - " that was loaded in by the optimizer. This indicates that" - " configuration '%s' was put on disk during the time that this" - " worker had the optimizer state lock OR that after obtaining the" - " optimizer state lock, somehow this configuration failed to be" - " loaded in and passed to the optimizer.", - sampled_config.id, - sampled_config.id, - ) - raise e - seed_state.recapture() self._seed_snapshot.update(seed_state) self._optimizer_state.update( @@ -375,8 +366,7 @@ def _report_trial_evaluation( """ # IMPORTANT: We need to attach the report to the trial before updating the things. trial.report = report - self._trials.update_trial(trial) - self._trials.write_version_file() + self._trials.update_trial(trial, hints=["report", "metadata", "state"]) logger.debug("Updated trial '%s' with status '%s'", trial.id, trial.state) if report.err is not None: @@ -435,11 +425,22 @@ def unsafe_retry_get_trial_by_id(self, trial_id: str) -> Trial: f"Failed to get trial '{trial_id}' after {N_UNSAFE_RETRIES} retries." 
) - def put_updated_trial(self, trial: Trial) -> None: - """Update the trial.""" + def put_updated_trial( + self, + trial: Trial, + *, + hints: list[TrialWriteHint] | TrialWriteHint | None = None, + ) -> None: + """Update the trial. + + Args: + trial: The trial to update. + hints: The hints to use when updating the trial. Defines what files need + to be updated. + If you don't know, leave `None`, this is a micro-optimization. + """ with self._trial_lock.lock(): - self._trials.update_trial(trial) - self._trials.write_version_file() + self._trials.update_trial(trial, hints=hints) @overload def lock_and_get_next_pending_trial(self) -> Trial | None: ... diff --git a/neps/status/status.py b/neps/status/status.py index 4b47eaeb..b738b353 100644 --- a/neps/status/status.py +++ b/neps/status/status.py @@ -9,6 +9,7 @@ import pandas as pd +from neps.runtime import get_workers_neps_state from neps.state.filebased import FileLocker from neps.state.neps_state import NePSState from neps.state.trial import Trial @@ -37,7 +38,11 @@ def get_summary_dict( # NOTE: We don't lock the shared state since we are just reading and don't need to # make decisions based on the state - shared_state = NePSState.create_or_load(root_directory, load_only=True) + try: + shared_state = get_workers_neps_state() + except RuntimeError: + shared_state = NePSState.create_or_load(root_directory) + trials = shared_state.lock_and_read_trials() evaluated: dict[ConfigID, _ConfigResultForStats] = {} diff --git a/neps/utils/files.py b/neps/utils/files.py index 95eb4eee..f2bdaad2 100644 --- a/neps/utils/files.py +++ b/neps/utils/files.py @@ -10,6 +10,14 @@ import yaml +try: + from yaml import ( + CSafeDumper as SafeDumper, # type: ignore + CSafeLoader as SafeLoader, # type: ignore + ) +except ImportError: + from yaml import SafeDumper, SafeLoader # type: ignore + def serializable_format(data: Any) -> Any: # noqa: PLR0911 """Format data to be serializable.""" @@ -47,7 +55,7 @@ def serialize(data: Any, path: Path | 
str, *, sort_keys: bool = True) -> None: path = Path(path) with path.open("w") as file_stream: try: - return yaml.safe_dump(data, file_stream, sort_keys=sort_keys) + return yaml.dump(data, file_stream, SafeDumper, sort_keys=sort_keys) except yaml.representer.RepresenterError as e: raise TypeError( "Could not serialize to yaml! The object " @@ -58,7 +66,7 @@ def serialize(data: Any, path: Path | str, *, sort_keys: bool = True) -> None: def deserialize(path: Path | str) -> dict[str, Any]: """Deserialize data from a yaml file.""" with Path(path).open("r") as file_stream: - data = yaml.full_load(file_stream) + data = yaml.load(file_stream, SafeLoader) if not isinstance(data, dict): raise TypeError( From 363c94daa3ec100987f95340869b4501eb171038 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 2 Dec 2024 21:15:18 +0100 Subject: [PATCH 32/56] optim: save torch tensors as numpy --- neps/state/filebased.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/neps/state/filebased.py b/neps/state/filebased.py index f6324860..eb9efe21 100644 --- a/neps/state/filebased.py +++ b/neps/state/filebased.py @@ -130,8 +130,8 @@ class ReaderWriterSeedSnapshot: PY_RNG_TUPLE_FILENAME: ClassVar = "py_rng.npy" NP_RNG_STATE_FILENAME: ClassVar = "np_rng_state.npy" - TORCH_RNG_STATE_FILENAME: ClassVar = "torch_rng_state.pt" - TORCH_CUDA_RNG_STATE_FILENAME: ClassVar = "torch_cuda_rng_state.pt" + TORCH_RNG_STATE_FILENAME: ClassVar = "torch_rng_state.npy" + TORCH_CUDA_RNG_STATE_FILENAME: ClassVar = "torch_cuda_rng_state.npy" SEED_INFO_FILENAME: ClassVar = "seed_info.json" @classmethod @@ -159,7 +159,9 @@ def read(cls, directory: Path) -> SeedSnapshot: import torch if torch_rng_path_exists: - torch_rng_state = torch.load(torch_rng_path, weights_only=True) + # OPTIM: This ends up being much faster to go to numpy + _bytes = np.fromfile(torch_rng_path, dtype=np.uint8) + torch_rng_state = torch.tensor(_bytes, dtype=torch.uint8) if torch_cuda_rng_path_exists: # By 
specifying `weights_only=True`, it disables arbitrary object loading @@ -211,7 +213,8 @@ def write(cls, snapshot: SeedSnapshot, directory: Path) -> None: if snapshot.torch_rng is not None: import torch - torch.save(snapshot.torch_rng, torch_rng_path) + # OPTIM: This ends up being much faster to go to numpy + snapshot.torch_rng.numpy().tofile(torch_rng_path) if snapshot.torch_cuda_rng is not None: import torch From e0527e47368dc438d87ce112ca74248293881972 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 2 Dec 2024 21:58:40 +0100 Subject: [PATCH 33/56] fix: Sample successful => evaluating --- neps/runtime.py | 8 +++++++- neps/state/filebased.py | 6 ++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/neps/runtime.py b/neps/runtime.py index 9051ece9..de98adac 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -382,7 +382,9 @@ def _get_next_trial(self) -> Trial | Literal["break"]: ) return earliest_pending - # Otherwise, we release the trial lock while sampling + # NOTE: It's important to release the trial lock before sampling + # as otherwise, any other service, such as reporting the result + # of a trial sampled_trial = self.state._sample_trial( optimizer=self.optimizer, worker_id=self.worker_id, @@ -391,6 +393,10 @@ def _get_next_trial(self) -> Trial | Literal["break"]: with self.state._trial_lock.lock(): try: + sampled_trial.set_evaluating( + time_started=time.time(), + worker_id=self.worker_id, + ) self.state._trials.new_trial(sampled_trial) return sampled_trial except TrialAlreadyExistsError as e: diff --git a/neps/state/filebased.py b/neps/state/filebased.py index eb9efe21..6a112cb6 100644 --- a/neps/state/filebased.py +++ b/neps/state/filebased.py @@ -98,6 +98,9 @@ def write( state_path.write_text(trial.state.value, encoding="utf-8") case _: raise ValueError(f"Invalid hint: {hints}") + elif isinstance(hints, list): + for hint in hints: + cls.write(trial, directory, hints=hint) elif hints is None: # We don't know, write everything 
serialize(trial.config, config_path) @@ -114,8 +117,7 @@ def write( report_path = directory / cls.REPORT_FILENAME serialize(asdict(trial.report), report_path) else: - for hint in hints: - cls.write(trial, directory, hints=hint) + raise ValueError(f"Invalid hint: {hints}") TrialReaderWriter: Final = ReaderWriterTrial() From dac1f3467e2a20c34a4bb728f0cea52d7334bbf8 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 2 Dec 2024 23:19:05 +0100 Subject: [PATCH 34/56] fix: grace period for os sync --- neps/env.py | 10 ++++++++++ neps/runtime.py | 39 ++++++++++++++++++++++++++++----------- neps/state/filebased.py | 38 +++++++++++++++++++++++++------------- neps/state/neps_state.py | 4 ++-- neps/utils/files.py | 17 ++++++++++++++--- 5 files changed, 79 insertions(+), 29 deletions(-) diff --git a/neps/env.py b/neps/env.py index c614ebac..f92ee21f 100644 --- a/neps/env.py +++ b/neps/env.py @@ -64,6 +64,16 @@ def is_nullable(e: str) -> bool: parse=lambda e: None if is_nullable(e) else float(e), default=120, ) +FS_SYNC_GRACE_BASE = get_env( + "NEPS_FS_SYNC_GRACE_BASE", + parse=float, + default=0.05, # Keep it low initially to not punish synced os +) +FS_SYNC_GRACE_INC = get_env( + "NEPS_FS_SYNC_GRACE_INC", + parse=float, + default=0.1, +) # NOTE: We want this to be greater than the trials filelock, so that # anything requesting to just update the trials is more likely to obtain it diff --git a/neps/runtime.py b/neps/runtime.py index de98adac..e2d9084c 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -20,6 +20,7 @@ ) from neps.env import ( + FS_SYNC_GRACE_INC, LINUX_FILELOCK_FUNCTION, MAX_RETRIES_CREATE_LOAD_STATE, MAX_RETRIES_GET_NEXT_TRIAL, @@ -32,6 +33,7 @@ WorkerRaiseError, ) from neps.state._eval import evaluate_trial +from neps.state.filebased import FileLocker from neps.state.neps_state import NePSState from neps.state.optimizer import BudgetInfo, OptimizationState, OptimizerInfo from neps.state.settings import DefaultReportValues, OnErrorPossibilities, 
WorkerSettings @@ -350,8 +352,6 @@ def _requires_global_stopping_criterion(self) -> bool: def _get_next_trial(self) -> Trial | Literal["break"]: # If there are no global stopping criterion, we can no just return early. with self.state._state_lock.lock(): - # With the trial lock, we'll load everything in, if we have a pending - # config, use that and return. with self.state._trial_lock.lock(): trials = self.state._trials.latest() @@ -380,6 +380,11 @@ def _get_next_trial(self) -> Trial | Literal["break"]: earliest_pending, hints=["metadata", "state"], ) + logger.info( + "Worker '%s' picked up pending trial: %s.", + self.worker_id, + earliest_pending.id, + ) return earliest_pending # NOTE: It's important to release the trial lock before sampling @@ -398,26 +403,38 @@ def _get_next_trial(self) -> Trial | Literal["break"]: worker_id=self.worker_id, ) self.state._trials.new_trial(sampled_trial) + logger.info( + "Worker '%s' sampled new trial: %s.", + self.worker_id, + sampled_trial.id, + ) return sampled_trial except TrialAlreadyExistsError as e: if sampled_trial.id in trials: - logger.warning( - "The new sampled trial was given an id of '%s', yet this already" + logger.error( + "The new sampled trial was given an id of '%s', yet this" " exists in the loaded in trials given to the optimizer. This" " indicates a bug with the optimizers allocation of ids.", sampled_trial.id, ) else: + _grace = FileLocker._GRACE + _inc = FS_SYNC_GRACE_INC logger.warning( - "The new sampled trial was given an id of '%s', which is not one" - " that was loaded in by the optimizer. This indicates that" - " configuration '%s' was put on disk during the time that this" - " worker had the optimizer state lock OR that after obtaining the" - " optimizer state lock, somehow this configuration failed to be" - " loaded in and passed to the optimizer.", - sampled_trial.id, + "The new sampled trial was given an id of '%s', which is not" + " one that was loaded in by the optimizer. 
This is usually" + " an indication that the file-system you are running on" + " is not atmoic in synchoronizing file operations." + " We have attempted to stabalize this but milage may vary." + " We are incrementing a grace period for file-locks from" + " '%s's to '%s's. You can control the initial" + " grace with 'NEPS_FS_SYNC_GRACE_BASE' and the increment with" + " 'NEPS_FS_SYNC_GRACE_INC'.", sampled_trial.id, + _grace, + _grace + _inc, ) + FileLocker._increse_grace(_inc) raise e # Forgive me lord, for I have sinned, this function is atrocious but complicated diff --git a/neps/state/filebased.py b/neps/state/filebased.py index 6a112cb6..bbb81d14 100644 --- a/neps/state/filebased.py +++ b/neps/state/filebased.py @@ -2,7 +2,9 @@ import json import logging +import os import pprint +import time from collections.abc import Iterator from contextlib import contextmanager from dataclasses import asdict, dataclass @@ -12,14 +14,12 @@ import numpy as np import portalocker as pl -from neps.env import ( - ENV_VARS_USED, -) +from neps.env import ENV_VARS_USED, FS_SYNC_GRACE_BASE from neps.state.err_dump import ErrDump from neps.state.optimizer import BudgetInfo, OptimizationState, OptimizerInfo from neps.state.seed_snapshot import SeedSnapshot from neps.state.trial import Trial -from neps.utils.files import deserialize, serialize +from neps.utils.files import atomic_write, deserialize, serialize logger = logging.getLogger(__name__) K = TypeVar("K") @@ -80,7 +80,7 @@ def write( case "config": serialize(trial.config, config_path) case "metadata": - with metadata_path.open("w") as f: + with atomic_write(metadata_path, "w") as f: json.dump(asdict(trial.metadata), f) if trial.metadata.previous_trial_id is not None: @@ -104,7 +104,7 @@ def write( elif hints is None: # We don't know, write everything serialize(trial.config, config_path) - with metadata_path.open("w") as f: + with atomic_write(metadata_path, "w") as f: json.dump(asdict(trial.metadata), f) if 
trial.metadata.previous_trial_id is not None: @@ -206,22 +206,25 @@ def write(cls, snapshot: SeedSnapshot, directory: Path) -> None: "py_rng_version": py_rng_version, "py_guass_next": py_guass_next, } - with seedinfo_path.open("w") as f: + with atomic_write(seedinfo_path, "w") as f: json.dump(seed_info, f) np_rng_state = snapshot.np_rng[1] - np_rng_state.tofile(np_rng_path) + with atomic_write(np_rng_path, "wb") as f: + np_rng_state.tofile(f) if snapshot.torch_rng is not None: import torch # OPTIM: This ends up being much faster to go to numpy - snapshot.torch_rng.numpy().tofile(torch_rng_path) + with atomic_write(torch_rng_path, "wb") as f: + snapshot.torch_rng.numpy().tofile(f) if snapshot.torch_cuda_rng is not None: import torch - torch.save(snapshot.torch_cuda_rng, torch_cuda_rng_path) + with atomic_write(torch_cuda_rng_path, "wb") as f: + torch.save(snapshot.torch_cuda_rng, f) @dataclass @@ -265,7 +268,7 @@ def read(cls, directory: Path) -> OptimizationState: @classmethod def write(cls, info: OptimizationState, directory: Path) -> None: info_path = directory / cls.STATE_FILE_NAME - with info_path.open("w") as f: + with atomic_write(info_path, "w") as f: json.dump(asdict(info), f) @@ -284,7 +287,7 @@ def read(cls, directory: Path) -> ErrDump: @classmethod def write(cls, err_dump: ErrDump, directory: Path) -> None: errors_path = directory / "errors.jsonl" - with errors_path.open("w") as f: + with atomic_write(errors_path, "w") as f: lines = [json.dumps(asdict(trial_err)) for trial_err in err_dump.errs] f.write("\n".join(lines)) @@ -305,10 +308,15 @@ class FileLocker: lock_path: Path poll: float timeout: float | None + _GRACE: ClassVar = FS_SYNC_GRACE_BASE def __post_init__(self) -> None: self.lock_path = self.lock_path.resolve().absolute() + @classmethod + def _increse_grace(cls, grace: float) -> None: + cls._GRACE = grace + cls._GRACE + @contextmanager def lock( self, @@ -317,6 +325,7 @@ def lock( ) -> Iterator[None]: self.lock_path.parent.mkdir(parents=True, 
exist_ok=True) self.lock_path.touch(exist_ok=True) + try: with pl.Lock( self.lock_path, @@ -324,8 +333,11 @@ def lock( timeout=self.timeout, flags=FILELOCK_EXCLUSIVE_NONE_BLOCKING, fail_when_locked=fail_if_locked, - ): + ) as fh: + time.sleep(self._GRACE) # Give the lock some time to yield + fh.flush() + os.fsync(fh) except pl.exceptions.LockException as e: raise pl.exceptions.LockException( f"Failed to acquire lock after timeout of {self.timeout} seconds." diff --git a/neps/state/neps_state.py b/neps/state/neps_state.py index 67f5ee20..4f6c53ed 100644 --- a/neps/state/neps_state.py +++ b/neps/state/neps_state.py @@ -46,6 +46,7 @@ from neps.state.optimizer import OptimizationState, OptimizerInfo from neps.state.seed_snapshot import SeedSnapshot from neps.state.trial import Report, Trial +from neps.utils.files import atomic_write if TYPE_CHECKING: from neps.optimizers.base_optimizer import BaseOptimizer @@ -126,7 +127,7 @@ def update_trial( self._write_version_file() def _write_version_file(self) -> None: - with self.version_file.open("wb") as f: + with atomic_write(self.version_file, "wb") as f: pickle.dump(self.versions, f) def trials_in_memory(self) -> dict[str, Trial]: @@ -368,7 +369,6 @@ def _report_trial_evaluation( trial.report = report self._trials.update_trial(trial, hints=["report", "metadata", "state"]) - logger.debug("Updated trial '%s' with status '%s'", trial.id, trial.state) if report.err is not None: with self._err_lock.lock(): err_dump = self._shared_errors.latest() diff --git a/neps/utils/files.py b/neps/utils/files.py index f2bdaad2..2ba9894e 100644 --- a/neps/utils/files.py +++ b/neps/utils/files.py @@ -3,10 +3,12 @@ from __future__ import annotations import dataclasses -from collections.abc import Iterable, Mapping +import os +from collections.abc import Iterable, Iterator, Mapping +from contextlib import contextmanager from enum import Enum from pathlib import Path -from typing import Any +from typing import IO, Any import yaml @@ -19,6 +21,15 
@@ from yaml import SafeDumper, SafeLoader # type: ignore +@contextmanager +def atomic_write(file_path: Path | str, *args: Any, **kwargs: Any) -> Iterator[IO]: + with open(file_path, *args, **kwargs) as file_stream: # noqa: PTH123 + yield file_stream + file_stream.flush() + os.fsync(file_stream.fileno()) + file_stream.close() + + def serializable_format(data: Any) -> Any: # noqa: PLR0911 """Format data to be serializable.""" if hasattr(data, "serialize"): @@ -53,7 +64,7 @@ def serialize(data: Any, path: Path | str, *, sort_keys: bool = True) -> None: """Serialize data to a yaml file.""" data = serializable_format(data) path = Path(path) - with path.open("w") as file_stream: + with atomic_write(path, "w") as file_stream: try: return yaml.dump(data, file_stream, SafeDumper, sort_keys=sort_keys) except yaml.representer.RepresenterError as e: From 209912d66fdea93224c9a138bf36bc93e4277765 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 9 Dec 2024 10:31:13 +0100 Subject: [PATCH 35/56] refactor: Rename to --- neps/runtime.py | 2 +- neps/state/neps_state.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/neps/runtime.py b/neps/runtime.py index e2d9084c..30930d5c 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -351,7 +351,7 @@ def _requires_global_stopping_criterion(self) -> bool: def _get_next_trial(self) -> Trial | Literal["break"]: # If there are no global stopping criterion, we can no just return early. 
- with self.state._state_lock.lock(): + with self.state._optimizer_lock.lock(): with self.state._trial_lock.lock(): trials = self.state._trials.latest() diff --git a/neps/state/neps_state.py b/neps/state/neps_state.py index 4f6c53ed..b9025c9f 100644 --- a/neps/state/neps_state.py +++ b/neps/state/neps_state.py @@ -224,7 +224,7 @@ class NePSState: _trial_lock: FileLocker = field(repr=False) _trials: TrialRepo = field(repr=False) - _state_lock: FileLocker = field(repr=False) + _optimizer_lock: FileLocker = field(repr=False) _optimizer_info: VersionedResource[OptimizerInfo] = field(repr=False) _seed_snapshot: VersionedResource[SeedSnapshot] = field(repr=False) _optimizer_state: VersionedResource[OptimizationState] = field(repr=False) @@ -239,7 +239,7 @@ def lock_and_read_trials(self) -> dict[str, Trial]: def lock_and_sample_trial(self, optimizer: BaseOptimizer, *, worker_id: str) -> Trial: """Acquire the state lock and sample a trial.""" - with self._state_lock.lock(): + with self._optimizer_lock.lock(): with self._trial_lock.lock(): trials = self._trials.latest() @@ -394,12 +394,12 @@ def lock_and_get_errors(self) -> ErrDump: def lock_and_get_optimizer_info(self) -> OptimizerInfo: """Get the optimizer information.""" - with self._state_lock.lock(): + with self._optimizer_lock.lock(): return self._optimizer_info.latest() def lock_and_get_optimizer_state(self) -> OptimizationState: """Get the optimizer state.""" - with self._state_lock.lock(): + with self._optimizer_lock.lock(): return self._optimizer_state.latest() def lock_and_get_trial_by_id(self, trial_id: str) -> Trial: From b5a01f5e813b1796d0e19c689a895f4ab677dc68 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 9 Dec 2024 10:49:20 +0100 Subject: [PATCH 36/56] refactor: Use pickle cache --- neps/state/neps_state.py | 51 ++++++++++++++-------------------------- 1 file changed, 18 insertions(+), 33 deletions(-) diff --git a/neps/state/neps_state.py b/neps/state/neps_state.py index b9025c9f..925c5e43 100644 
--- a/neps/state/neps_state.py +++ b/neps/state/neps_state.py @@ -46,7 +46,6 @@ from neps.state.optimizer import OptimizationState, OptimizerInfo from neps.state.seed_snapshot import SeedSnapshot from neps.state.trial import Report, Trial -from neps.utils.files import atomic_write if TYPE_CHECKING: from neps.optimizers.base_optimizer import BaseOptimizer @@ -76,13 +75,14 @@ def make_sha() -> Version: # TODO: Ergonomics of this class sucks @dataclass class TrialRepo: + CACHE_FILE_NAME = ".trial_cache.pkl" + directory: Path - version_file: Path - trial_cache: dict[str, Trial] = field(default_factory=dict) - versions: dict[str, Version] = field(default_factory=dict) + cache_path: Path = field(init=False) def __post_init__(self) -> None: self.directory.mkdir(parents=True, exist_ok=True) + self.cache_path = self.directory / self.CACHE_FILE_NAME def list_trial_ids(self) -> list[str]: return [ @@ -92,23 +92,11 @@ def list_trial_ids(self) -> list[str]: ] def latest(self) -> dict[str, Trial]: - if not self.version_file.exists(): + if not self.cache_path.exists(): return {} - with self.version_file.open("rb") as f: - versions_on_disk = pickle.load(f) # noqa: S301 - - stale: dict[str, Version] = { - k: v - for k, v in versions_on_disk.items() - if self.versions.get(k, "__not_found__") != v - } - for trial_id, loaded_version in stale.items(): - loaded_trial = self.load_trial_from_disk(trial_id) - self.trial_cache[trial_id] = loaded_trial - self.versions[trial_id] = loaded_version - - return self.trial_cache + with self.cache_path.open("rb") as f: + return pickle.load(f) # noqa: S301 def new_trial(self, trial: Trial) -> None: config_path = self.directory / f"config_{trial.id}" @@ -118,20 +106,17 @@ def new_trial(self, trial: Trial) -> None: self.update_trial(trial, hints=None) def update_trial( - self, trial: Trial, *, hints: list[TrialWriteHint] | TrialWriteHint | None = None + self, + trial: Trial, + *, + hints: list[TrialWriteHint] | TrialWriteHint | None = None, ) -> None: 
- TrialReaderWriter.write(trial, self.directory / f"config_{trial.id}", hints=hints) - self.trial_cache[trial.id] = trial - new_version = make_sha() - self.versions[trial.id] = new_version - self._write_version_file() - - def _write_version_file(self) -> None: - with atomic_write(self.version_file, "wb") as f: - pickle.dump(self.versions, f) + trials = self.latest() + with self.cache_path.open("wb") as f: + trials[trial.id] = trial + pickle.dump(trials, f) - def trials_in_memory(self) -> dict[str, Trial]: - return self.trial_cache + TrialReaderWriter.write(trial, self.directory / f"config_{trial.id}", hints=hints) def load_trial_from_disk(self, trial_id: str) -> Trial: config_path = self.directory / f"config_{trial_id}" @@ -596,14 +581,14 @@ def create_or_load( return cls( path=path, - _trials=TrialRepo(config_dir, version_file=config_dir / ".versions"), + _trials=TrialRepo(config_dir), # Locks, _trial_lock=FileLocker( lock_path=path / ".configs.lock", poll=TRIAL_FILELOCK_POLL, timeout=TRIAL_FILELOCK_TIMEOUT, ), - _state_lock=FileLocker( + _optimizer_lock=FileLocker( lock_path=path / ".state.lock", poll=STATE_FILELOCK_POLL, timeout=STATE_FILELOCK_TIMEOUT, From 65b2bf9911a516dbbcde162f13a61ed34f286fca Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 9 Dec 2024 10:56:10 +0100 Subject: [PATCH 37/56] refactor: Move GRACE usage to worker --- neps/runtime.py | 10 +++++++--- neps/state/filebased.py | 9 +-------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/neps/runtime.py b/neps/runtime.py index 30930d5c..5f232314 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -14,12 +14,14 @@ from typing import ( TYPE_CHECKING, Any, + ClassVar, Generic, Literal, TypeVar, ) from neps.env import ( + FS_SYNC_GRACE_BASE, FS_SYNC_GRACE_INC, LINUX_FILELOCK_FUNCTION, MAX_RETRIES_CREATE_LOAD_STATE, @@ -33,7 +35,6 @@ WorkerRaiseError, ) from neps.state._eval import evaluate_trial -from neps.state.filebased import FileLocker from neps.state.neps_state import 
NePSState from neps.state.optimizer import BudgetInfo, OptimizationState, OptimizerInfo from neps.state.settings import DefaultReportValues, OnErrorPossibilities, WorkerSettings @@ -156,6 +157,8 @@ class DefaultWorker(Generic[Loc]): worker_cumulative_evaluation_time_seconds: float = 0.0 """The time spent evaluating configurations by this worker.""" + _GRACE: ClassVar = FS_SYNC_GRACE_BASE + @classmethod def new( cls, @@ -353,6 +356,7 @@ def _get_next_trial(self) -> Trial | Literal["break"]: # If there are no global stopping criterion, we can no just return early. with self.state._optimizer_lock.lock(): with self.state._trial_lock.lock(): + time.sleep(self._GRACE) # Give the lock some time to trials = self.state._trials.latest() if self._requires_global_stopping_criterion: @@ -418,7 +422,7 @@ def _get_next_trial(self) -> Trial | Literal["break"]: sampled_trial.id, ) else: - _grace = FileLocker._GRACE + _grace = DefaultWorker._GRACE _inc = FS_SYNC_GRACE_INC logger.warning( "The new sampled trial was given an id of '%s', which is not" @@ -434,7 +438,7 @@ def _get_next_trial(self) -> Trial | Literal["break"]: _grace, _grace + _inc, ) - FileLocker._increse_grace(_inc) + DefaultWorker._GRACE = _grace + FS_SYNC_GRACE_INC raise e # Forgive me lord, for I have sinned, this function is atrocious but complicated diff --git a/neps/state/filebased.py b/neps/state/filebased.py index bbb81d14..7b25cc7f 100644 --- a/neps/state/filebased.py +++ b/neps/state/filebased.py @@ -4,7 +4,6 @@ import logging import os import pprint -import time from collections.abc import Iterator from contextlib import contextmanager from dataclasses import asdict, dataclass @@ -14,7 +13,7 @@ import numpy as np import portalocker as pl -from neps.env import ENV_VARS_USED, FS_SYNC_GRACE_BASE +from neps.env import ENV_VARS_USED from neps.state.err_dump import ErrDump from neps.state.optimizer import BudgetInfo, OptimizationState, OptimizerInfo from neps.state.seed_snapshot import SeedSnapshot @@ -308,15 
+307,10 @@ class FileLocker: lock_path: Path poll: float timeout: float | None - _GRACE: ClassVar = FS_SYNC_GRACE_BASE def __post_init__(self) -> None: self.lock_path = self.lock_path.resolve().absolute() - @classmethod - def _increse_grace(cls, grace: float) -> None: - cls._GRACE = grace + cls._GRACE - @contextmanager def lock( self, @@ -334,7 +328,6 @@ def lock( flags=FILELOCK_EXCLUSIVE_NONE_BLOCKING, fail_when_locked=fail_if_locked, ) as fh: - time.sleep(self._GRACE) # Give the lock some time to yield fh.flush() os.fsync(fh) From b146d050aee441a6776400fe56e4e79d3a49c93a Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 9 Dec 2024 11:02:13 +0100 Subject: [PATCH 38/56] fix: Load in from disk if cache missing --- neps/state/neps_state.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/neps/state/neps_state.py b/neps/state/neps_state.py index 925c5e43..85bbcd3f 100644 --- a/neps/state/neps_state.py +++ b/neps/state/neps_state.py @@ -93,6 +93,16 @@ def list_trial_ids(self) -> list[str]: def latest(self) -> dict[str, Trial]: if not self.cache_path.exists(): + # If we end up with no cache but there are trials on disk, we need to + # read them in. 
However we will not save back the cache here in fear of + # overwriting + if any(path.name.startswith("config_") for path in self.directory.iterdir()): + trial_ids = self.list_trial_ids() + return { + trial_id: self.load_trial_from_disk(trial_id) + for trial_id in trial_ids + } + return {} with self.cache_path.open("rb") as f: From 56d8acc833003a115fa701ae26af9ba4089cecc4 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 9 Dec 2024 13:36:24 +0100 Subject: [PATCH 39/56] debug: Add log file --- neps/runtime.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/neps/runtime.py b/neps/runtime.py index 5f232314..7b0df480 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -357,6 +357,11 @@ def _get_next_trial(self) -> Trial | Literal["break"]: with self.state._optimizer_lock.lock(): with self.state._trial_lock.lock(): time.sleep(self._GRACE) # Give the lock some time to + logger.info("I, MR WORKER %s obtained thel lock", self.worker_id) + DEBUG_COUNT_FILE = Path(self.state.path / "DEBUG_COUNT_FILE") + with DEBUG_COUNT_FILE.open("a") as f: + f.write(f"{self.worker_id}\n") + trials = self.state._trials.latest() if self._requires_global_stopping_criterion: @@ -407,6 +412,9 @@ def _get_next_trial(self) -> Trial | Literal["break"]: worker_id=self.worker_id, ) self.state._trials.new_trial(sampled_trial) + logger.info( + "I, MR WORKER %s, SAMPLED %s", self.worker_id, sampled_trial.id + ) logger.info( "Worker '%s' sampled new trial: %s.", self.worker_id, From 925d9124ee1d2347ff090528e48f9298b6c51d78 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 9 Dec 2024 13:38:35 +0100 Subject: [PATCH 40/56] debug: more... 
--- neps/runtime.py | 1 + 1 file changed, 1 insertion(+) diff --git a/neps/runtime.py b/neps/runtime.py index 7b0df480..87186e23 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -395,6 +395,7 @@ def _get_next_trial(self) -> Trial | Literal["break"]: earliest_pending.id, ) return earliest_pending + logger.info("I, MR WORKER %s released thel lock", self.worker_id) # NOTE: It's important to release the trial lock before sampling # as otherwise, any other service, such as reporting the result From 981eb2d54b652f5613aa5974ec617e21a2cd2d71 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 9 Dec 2024 13:43:43 +0100 Subject: [PATCH 41/56] debug: MORE... --- neps/runtime.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/neps/runtime.py b/neps/runtime.py index 87186e23..158af9c3 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -360,7 +360,7 @@ def _get_next_trial(self) -> Trial | Literal["break"]: logger.info("I, MR WORKER %s obtained thel lock", self.worker_id) DEBUG_COUNT_FILE = Path(self.state.path / "DEBUG_COUNT_FILE") with DEBUG_COUNT_FILE.open("a") as f: - f.write(f"{self.worker_id}\n") + f.write(f"locked: {time.time()}\n") trials = self.state._trials.latest() @@ -396,6 +396,8 @@ def _get_next_trial(self) -> Trial | Literal["break"]: ) return earliest_pending logger.info("I, MR WORKER %s released thel lock", self.worker_id) + with DEBUG_COUNT_FILE.open("a") as f: + f.write(f"unlocked: {time.time()}\n") # NOTE: It's important to release the trial lock before sampling # as otherwise, any other service, such as reporting the result From cc6073827cdf2be585a224feb6ce366cc6b9f19f Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 9 Dec 2024 13:57:19 +0100 Subject: [PATCH 42/56] debug: last ditch effort --- neps/runtime.py | 4 ++-- neps/state/filebased.py | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/neps/runtime.py b/neps/runtime.py index 158af9c3..baf54e05 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ 
-355,15 +355,15 @@ def _requires_global_stopping_criterion(self) -> bool: def _get_next_trial(self) -> Trial | Literal["break"]: # If there are no global stopping criterion, we can no just return early. with self.state._optimizer_lock.lock(): + time.sleep(self._GRACE) # Give the lock some time with self.state._trial_lock.lock(): - time.sleep(self._GRACE) # Give the lock some time to + time.sleep(self._GRACE) # Give the lock some time logger.info("I, MR WORKER %s obtained thel lock", self.worker_id) DEBUG_COUNT_FILE = Path(self.state.path / "DEBUG_COUNT_FILE") with DEBUG_COUNT_FILE.open("a") as f: f.write(f"locked: {time.time()}\n") trials = self.state._trials.latest() - if self._requires_global_stopping_criterion: should_stop = self._check_global_stopping_criterion(trials) if should_stop is not False: diff --git a/neps/state/filebased.py b/neps/state/filebased.py index 7b25cc7f..a1337837 100644 --- a/neps/state/filebased.py +++ b/neps/state/filebased.py @@ -5,6 +5,7 @@ import os import pprint from collections.abc import Iterator +import time from contextlib import contextmanager from dataclasses import asdict, dataclass from pathlib import Path @@ -323,11 +324,14 @@ def lock( try: with pl.Lock( self.lock_path, + mode="wb", check_interval=self.poll, timeout=self.timeout, flags=FILELOCK_EXCLUSIVE_NONE_BLOCKING, fail_when_locked=fail_if_locked, ) as fh: + fh.write(f"{time.time()}".encode("utf-8")) # noqa: UP012 + os.fsync(fh) yield fh.flush() os.fsync(fh) From f3d23d17274adc35e63c7612401975076b9a1e38 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 9 Dec 2024 14:09:03 +0100 Subject: [PATCH 43/56] debug: More info on lock method --- neps/runtime.py | 5 ++++- neps/state/filebased.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/neps/runtime.py b/neps/runtime.py index baf54e05..4ef2592a 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -681,7 +681,10 @@ def _launch_runtime( # noqa: PLR0913 elif LINUX_FILELOCK_FUNCTION.lower() == 
"lockf": setattr(portalocker_lock_module, "LOCKER", fcntl.lockf) else: - pass + raise ValueError( + f"Unknown file-locking function '{LINUX_FILELOCK_FUNCTION}'." + " Must be one of 'flock' or 'lockf'." + ) except ImportError: pass diff --git a/neps/state/filebased.py b/neps/state/filebased.py index a1337837..3142d2d8 100644 --- a/neps/state/filebased.py +++ b/neps/state/filebased.py @@ -4,8 +4,8 @@ import logging import os import pprint -from collections.abc import Iterator import time +from collections.abc import Iterator from contextlib import contextmanager from dataclasses import asdict, dataclass from pathlib import Path @@ -330,6 +330,9 @@ def lock( flags=FILELOCK_EXCLUSIVE_NONE_BLOCKING, fail_when_locked=fail_if_locked, ) as fh: + import portalocker.portalocker as pl_module + + logger.info(pl_module.LOCKER) fh.write(f"{time.time()}".encode("utf-8")) # noqa: UP012 os.fsync(fh) yield From c440b7efde38608461919cca8497b4916c91f0df Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 9 Dec 2024 14:46:03 +0100 Subject: [PATCH 44/56] debug: MOREEEEEEEE --- neps/runtime.py | 2 -- neps/state/filebased.py | 12 +++++++++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/neps/runtime.py b/neps/runtime.py index 4ef2592a..ba84b051 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -358,7 +358,6 @@ def _get_next_trial(self) -> Trial | Literal["break"]: time.sleep(self._GRACE) # Give the lock some time with self.state._trial_lock.lock(): time.sleep(self._GRACE) # Give the lock some time - logger.info("I, MR WORKER %s obtained thel lock", self.worker_id) DEBUG_COUNT_FILE = Path(self.state.path / "DEBUG_COUNT_FILE") with DEBUG_COUNT_FILE.open("a") as f: f.write(f"locked: {time.time()}\n") @@ -395,7 +394,6 @@ def _get_next_trial(self) -> Trial | Literal["break"]: earliest_pending.id, ) return earliest_pending - logger.info("I, MR WORKER %s released thel lock", self.worker_id) with DEBUG_COUNT_FILE.open("a") as f: f.write(f"unlocked: {time.time()}\n") diff 
--git a/neps/state/filebased.py b/neps/state/filebased.py index 3142d2d8..794e9a1b 100644 --- a/neps/state/filebased.py +++ b/neps/state/filebased.py @@ -317,14 +317,13 @@ def lock( self, *, fail_if_locked: bool = False, + worker_id: str | None = None, ) -> Iterator[None]: self.lock_path.parent.mkdir(parents=True, exist_ok=True) - self.lock_path.touch(exist_ok=True) try: with pl.Lock( self.lock_path, - mode="wb", check_interval=self.poll, timeout=self.timeout, flags=FILELOCK_EXCLUSIVE_NONE_BLOCKING, @@ -332,7 +331,14 @@ def lock( ) as fh: import portalocker.portalocker as pl_module - logger.info(pl_module.LOCKER) + if worker_id is not None: + logger.debug( + "Worker %s acquired lock on %s using %s at %s", + worker_id, + self.lock_path, + pl_module.LOCKER, + time.time(), + ) fh.write(f"{time.time()}".encode("utf-8")) # noqa: UP012 os.fsync(fh) yield From 129bf3260e68c3dcacf566262d63a349efc11928 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 9 Dec 2024 14:47:51 +0100 Subject: [PATCH 45/56] debug: "...and more" --- neps/state/filebased.py | 1 - 1 file changed, 1 deletion(-) diff --git a/neps/state/filebased.py b/neps/state/filebased.py index 794e9a1b..9693f381 100644 --- a/neps/state/filebased.py +++ b/neps/state/filebased.py @@ -339,7 +339,6 @@ def lock( pl_module.LOCKER, time.time(), ) - fh.write(f"{time.time()}".encode("utf-8")) # noqa: UP012 os.fsync(fh) yield fh.flush() From 0603b438cf1d64c3a1a2b43d36195ceef40121b1 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 9 Dec 2024 14:49:38 +0100 Subject: [PATCH 46/56] debug: "you know the drill" --- neps/state/filebased.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/neps/state/filebased.py b/neps/state/filebased.py index 9693f381..d50ae3b7 100644 --- a/neps/state/filebased.py +++ b/neps/state/filebased.py @@ -1,5 +1,6 @@ from __future__ import annotations +import contextlib import json import logging import os @@ -353,3 +354,12 @@ def lock( " environment variables to increase the 
timeout:" f"\n\n{pprint.pformat(ENV_VARS_USED)}" ) from e + finally: + if worker_id is not None: + with contextlib.suppress(Exception): + logger.debug( + "Worker %s released lock on %s at %s", + worker_id, + self.lock_path, + time.time(), + ) From 2339ce77a64ad876570114c79de24f4910ce8a95 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 9 Dec 2024 14:51:15 +0100 Subject: [PATCH 47/56] debug: "..." --- neps/runtime.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/neps/runtime.py b/neps/runtime.py index ba84b051..b22f9a0c 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -354,9 +354,9 @@ def _requires_global_stopping_criterion(self) -> bool: def _get_next_trial(self) -> Trial | Literal["break"]: # If there are no global stopping criterion, we can no just return early. - with self.state._optimizer_lock.lock(): + with self.state._optimizer_lock.lock(worker_id=self.worker_id): time.sleep(self._GRACE) # Give the lock some time - with self.state._trial_lock.lock(): + with self.state._trial_lock.lock(worker_id=self.worker_id): time.sleep(self._GRACE) # Give the lock some time DEBUG_COUNT_FILE = Path(self.state.path / "DEBUG_COUNT_FILE") with DEBUG_COUNT_FILE.open("a") as f: @@ -406,16 +406,13 @@ def _get_next_trial(self) -> Trial | Literal["break"]: trials=trials, ) - with self.state._trial_lock.lock(): + with self.state._trial_lock.lock(worker_id=self.worker_id): try: sampled_trial.set_evaluating( time_started=time.time(), worker_id=self.worker_id, ) self.state._trials.new_trial(sampled_trial) - logger.info( - "I, MR WORKER %s, SAMPLED %s", self.worker_id, sampled_trial.id - ) logger.info( "Worker '%s' sampled new trial: %s.", self.worker_id, @@ -568,7 +565,7 @@ def run(self) -> None: # noqa: C901, PLR0915 # We do not retry this, as if some other worker has # managed to manipulate this trial in the meantime, # then something has gone wrong - with self.state._trial_lock.lock(): + with 
self.state._trial_lock.lock(worker_id=self.worker_id): self.state._report_trial_evaluation( trial=evaluated_trial, report=report, From 7d52f3d51051c1273f6d9f4e4994e88fa17dff7e Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 9 Dec 2024 14:55:44 +0100 Subject: [PATCH 48/56] fix: Delay config dir creation until config written --- neps/state/neps_state.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/neps/state/neps_state.py b/neps/state/neps_state.py index 85bbcd3f..e46d25d9 100644 --- a/neps/state/neps_state.py +++ b/neps/state/neps_state.py @@ -112,8 +112,13 @@ def new_trial(self, trial: Trial) -> None: config_path = self.directory / f"config_{trial.id}" if config_path.exists(): raise TrialAlreadyExistsError(trial.id, config_path) + trials = self.latest() + with self.cache_path.open("wb") as f: + trials[trial.id] = trial + pickle.dump(trials, f) + config_path.mkdir(parents=True, exist_ok=True) - self.update_trial(trial, hints=None) + TrialReaderWriter.write(trial, self.directory / f"config_{trial.id}", hints=None) def update_trial( self, From 5093b73ff769d635c14381cc8387247d0d0c82be Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 9 Dec 2024 15:28:55 +0100 Subject: [PATCH 49/56] fix: Yay, remove debugging statements --- neps/runtime.py | 9 +-------- neps/state/filebased.py | 7 ++----- 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/neps/runtime.py b/neps/runtime.py index b22f9a0c..35228c57 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -355,13 +355,8 @@ def _requires_global_stopping_criterion(self) -> bool: def _get_next_trial(self) -> Trial | Literal["break"]: # If there are no global stopping criterion, we can no just return early. 
with self.state._optimizer_lock.lock(worker_id=self.worker_id): - time.sleep(self._GRACE) # Give the lock some time with self.state._trial_lock.lock(worker_id=self.worker_id): - time.sleep(self._GRACE) # Give the lock some time - DEBUG_COUNT_FILE = Path(self.state.path / "DEBUG_COUNT_FILE") - with DEBUG_COUNT_FILE.open("a") as f: - f.write(f"locked: {time.time()}\n") - + time.sleep(self._GRACE) # Give the FS some time to sync trials = self.state._trials.latest() if self._requires_global_stopping_criterion: should_stop = self._check_global_stopping_criterion(trials) @@ -394,8 +389,6 @@ def _get_next_trial(self) -> Trial | Literal["break"]: earliest_pending.id, ) return earliest_pending - with DEBUG_COUNT_FILE.open("a") as f: - f.write(f"unlocked: {time.time()}\n") # NOTE: It's important to release the trial lock before sampling # as otherwise, any other service, such as reporting the result diff --git a/neps/state/filebased.py b/neps/state/filebased.py index d50ae3b7..d550d18d 100644 --- a/neps/state/filebased.py +++ b/neps/state/filebased.py @@ -330,17 +330,14 @@ def lock( flags=FILELOCK_EXCLUSIVE_NONE_BLOCKING, fail_when_locked=fail_if_locked, ) as fh: - import portalocker.portalocker as pl_module - if worker_id is not None: logger.debug( - "Worker %s acquired lock on %s using %s at %s", + "Worker %s acquired lock on %s at %s", worker_id, self.lock_path, - pl_module.LOCKER, time.time(), ) - os.fsync(fh) + yield fh.flush() os.fsync(fh) From a51a286481a0169676ef6b690aa48e909a634e2a Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Mon, 9 Dec 2024 20:19:35 +0100 Subject: [PATCH 50/56] optim: Speed optimizations --- neps/env.py | 21 +++++- neps/runtime.py | 13 ++-- neps/state/filebased.py | 140 ++++++++++++++------------------------- neps/state/neps_state.py | 72 ++++++++++++++++---- neps/utils/files.py | 62 ++++++++++++----- 5 files changed, 183 insertions(+), 125 deletions(-) diff --git a/neps/env.py b/neps/env.py index f92ee21f..01a5883b 100644 --- a/neps/env.py 
+++ b/neps/env.py @@ -4,7 +4,7 @@ import os from collections.abc import Callable -from typing import Any, TypeVar +from typing import Any, Literal, TypeVar T = TypeVar("T") V = TypeVar("V") @@ -28,6 +28,13 @@ def is_nullable(e: str) -> bool: return e.lower() in ("none", "n", "null") +def yaml_or_json(e: str) -> Literal["yaml", "json"]: + """Check if an environment variable is either yaml or json.""" + if e.lower() in ("yaml", "json"): + return e.lower() # type: ignore + raise ValueError(f"Expected 'yaml' or 'json', got '{e}'.") + + LINUX_FILELOCK_FUNCTION = get_env( "NEPS_LINUX_FILELOCK_FUNCTION", parse=str, @@ -67,7 +74,7 @@ def is_nullable(e: str) -> bool: FS_SYNC_GRACE_BASE = get_env( "NEPS_FS_SYNC_GRACE_BASE", parse=float, - default=0.05, # Keep it low initially to not punish synced os + default=0.00, # Keep it low initially to not punish synced os ) FS_SYNC_GRACE_INC = get_env( "NEPS_FS_SYNC_GRACE_INC", @@ -100,3 +107,13 @@ def is_nullable(e: str) -> bool: parse=lambda e: None if is_nullable(e) else float(e), default=120, ) +TRIAL_CACHE_MAX_UPDATES_BEFORE_CONSOLIDATION = get_env( + "NEPS_TRIAL_CACHE_MAX_UPDATES_BEFORE_CONSOLIDATION", + parse=int, + default=30, +) +CONFIG_SERIALIZE_FORMAT: Literal["yaml", "json"] = get_env( # type: ignore + "NEPS_CONFIG_SERIALIZE_FORMAT", + parse=yaml_or_json, + default="yaml", +) diff --git a/neps/runtime.py b/neps/runtime.py index 35228c57..7f5b1ae5 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -355,9 +355,17 @@ def _requires_global_stopping_criterion(self) -> bool: def _get_next_trial(self) -> Trial | Literal["break"]: # If there are no global stopping criterion, we can no just return early. with self.state._optimizer_lock.lock(worker_id=self.worker_id): + # NOTE: It's important to release the trial lock before sampling + # as otherwise, any other service, such as reporting the result + # of a trial. Hence we do not lock these together with the above. 
with self.state._trial_lock.lock(worker_id=self.worker_id): - time.sleep(self._GRACE) # Give the FS some time to sync + # Give the file-system some time to sync if we encountered out-of-order + # issues with this worker. + if self._GRACE > 0: + time.sleep(self._GRACE) + trials = self.state._trials.latest() + if self._requires_global_stopping_criterion: should_stop = self._check_global_stopping_criterion(trials) if should_stop is not False: @@ -390,9 +398,6 @@ def _get_next_trial(self) -> Trial | Literal["break"]: ) return earliest_pending - # NOTE: It's important to release the trial lock before sampling - # as otherwise, any other service, such as reporting the result - # of a trial sampled_trial = self.state._sample_trial( optimizer=self.optimizer, worker_id=self.worker_id, diff --git a/neps/state/filebased.py b/neps/state/filebased.py index d550d18d..681c4d60 100644 --- a/neps/state/filebased.py +++ b/neps/state/filebased.py @@ -3,7 +3,7 @@ import contextlib import json import logging -import os +import pickle import pprint import time from collections.abc import Iterator @@ -15,12 +15,12 @@ import numpy as np import portalocker as pl -from neps.env import ENV_VARS_USED +from neps.env import CONFIG_SERIALIZE_FORMAT, ENV_VARS_USED from neps.state.err_dump import ErrDump from neps.state.optimizer import BudgetInfo, OptimizationState, OptimizerInfo from neps.state.seed_snapshot import SeedSnapshot from neps.state.trial import Trial -from neps.utils.files import atomic_write, deserialize, serialize +from neps.utils.files import deserialize, serialize logger = logging.getLogger(__name__) K = TypeVar("K") @@ -35,8 +35,8 @@ class ReaderWriterTrial: # Report and config are kept as yaml since they are most likely to be # read - CONFIG_FILENAME = "config.yaml" - REPORT_FILENAME = "report.yaml" + CONFIG_FILENAME = f"config.{CONFIG_SERIALIZE_FORMAT}" + REPORT_FILENAME = f"report.{CONFIG_SERIALIZE_FORMAT}" # Metadata is put as json as it's more likely to be machine read and 
# is much faster. @@ -56,11 +56,15 @@ def read(cls, directory: Path) -> Trial: metadata = json.load(f) return Trial( - config=deserialize(config_path), + config=deserialize(config_path, file_format=CONFIG_SERIALIZE_FORMAT), metadata=Trial.MetaData(**metadata), state=Trial.State(state_path.read_text(encoding="utf-8").strip()), report=( - Trial.Report(**deserialize(report_path)) if report_path.exists() else None + Trial.Report( + **deserialize(report_path, file_format=CONFIG_SERIALIZE_FORMAT), + ) + if report_path.exists() + else None ), ) @@ -79,10 +83,16 @@ def write( if isinstance(hints, str): match hints: case "config": - serialize(trial.config, config_path) + serialize( + trial.config, + config_path, + check_serialized=False, + file_format=CONFIG_SERIALIZE_FORMAT, + ) case "metadata": - with atomic_write(metadata_path, "w") as f: - json.dump(asdict(trial.metadata), f) + data = asdict(trial.metadata) + with metadata_path.open("w") as f: + json.dump(data, f) if trial.metadata.previous_trial_id is not None: previous_trial_path = directory / cls.PREVIOUS_TRIAL_ID_FILENAME @@ -94,7 +104,16 @@ def write( ) report_path = directory / cls.REPORT_FILENAME - serialize(asdict(trial.report), report_path) + _report = asdict(trial.report) + if (err := _report.get("err")) is not None: + _report["err"] = str(err) + + serialize( + _report, + report_path, + check_serialized=False, + file_format=CONFIG_SERIALIZE_FORMAT, + ) case "state": state_path.write_text(trial.state.value, encoding="utf-8") case _: @@ -104,19 +123,10 @@ def write( cls.write(trial, directory, hints=hint) elif hints is None: # We don't know, write everything - serialize(trial.config, config_path) - with atomic_write(metadata_path, "w") as f: - json.dump(asdict(trial.metadata), f) - - if trial.metadata.previous_trial_id is not None: - previous_trial_path = directory / cls.PREVIOUS_TRIAL_ID_FILENAME - previous_trial_path.write_text(trial.metadata.previous_trial_id) - - state_path.write_text(trial.state.value, 
encoding="utf-8") + cls.write(trial, directory, hints=["config", "metadata", "state"]) if trial.report is not None: - report_path = directory / cls.REPORT_FILENAME - serialize(asdict(trial.report), report_path) + cls.write(trial, directory, hints="report") else: raise ValueError(f"Invalid hint: {hints}") @@ -130,75 +140,35 @@ class ReaderWriterSeedSnapshot: # It seems like they're all uint32 but I can't be sure. PY_RNG_STATE_DTYPE: ClassVar = np.int64 - - PY_RNG_TUPLE_FILENAME: ClassVar = "py_rng.npy" - NP_RNG_STATE_FILENAME: ClassVar = "np_rng_state.npy" - TORCH_RNG_STATE_FILENAME: ClassVar = "torch_rng_state.npy" - TORCH_CUDA_RNG_STATE_FILENAME: ClassVar = "torch_cuda_rng_state.npy" - SEED_INFO_FILENAME: ClassVar = "seed_info.json" + SEED_FILENAME: ClassVar = "seed.pickle" @classmethod def read(cls, directory: Path) -> SeedSnapshot: - seedinfo_path = directory / cls.SEED_INFO_FILENAME - py_rng_path = directory / cls.PY_RNG_TUPLE_FILENAME - np_rng_path = directory / cls.NP_RNG_STATE_FILENAME - torch_rng_path = directory / cls.TORCH_RNG_STATE_FILENAME - torch_cuda_rng_path = directory / cls.TORCH_CUDA_RNG_STATE_FILENAME - - # Load and set pythons rng - py_rng_state = tuple( - int(x) for x in np.fromfile(py_rng_path, dtype=cls.PY_RNG_STATE_DTYPE) - ) - np_rng_state = np.fromfile(np_rng_path, dtype=np.uint32) - seed_info = deserialize(seedinfo_path) - - torch_rng_path_exists = torch_rng_path.exists() - torch_cuda_rng_path_exists = torch_cuda_rng_path.exists() - - # By specifying `weights_only=True`, it disables arbitrary object loading - torch_rng_state = None - torch_cuda_rng = None - if torch_rng_path_exists or torch_cuda_rng_path_exists: - import torch + seedinfo_path = directory / cls.SEED_FILENAME - if torch_rng_path_exists: - # OPTIM: This ends up being much faster to go to numpy - _bytes = np.fromfile(torch_rng_path, dtype=np.uint8) - torch_rng_state = torch.tensor(_bytes, dtype=torch.uint8) - - if torch_cuda_rng_path_exists: - # By specifying 
`weights_only=True`, it disables arbitrary object loading - torch_cuda_rng = torch.load(torch_cuda_rng_path, weights_only=True) + with seedinfo_path.open("rb") as f: + seed_info = pickle.load(f) return SeedSnapshot( np_rng=( seed_info["np_rng_kind"], - np_rng_state, + seed_info["np_rng_state"], seed_info["np_pos"], seed_info["np_has_gauss"], seed_info["np_cached_gauss"], ), py_rng=( seed_info["py_rng_version"], - py_rng_state, + seed_info["py_rng_state"], seed_info["py_guass_next"], ), - torch_rng=torch_rng_state, - torch_cuda_rng=torch_cuda_rng, + torch_rng=seed_info.get("torch_rng"), + torch_cuda_rng=seed_info.get("torch_cuda_rng"), ) @classmethod def write(cls, snapshot: SeedSnapshot, directory: Path) -> None: - seedinfo_path = directory / cls.SEED_INFO_FILENAME - py_rng_path = directory / cls.PY_RNG_TUPLE_FILENAME - np_rng_path = directory / cls.NP_RNG_STATE_FILENAME - torch_rng_path = directory / cls.TORCH_RNG_STATE_FILENAME - torch_cuda_rng_path = directory / cls.TORCH_CUDA_RNG_STATE_FILENAME - py_rng_version, py_rng_state, py_guass_next = snapshot.py_rng - - np.array(py_rng_state, dtype=cls.PY_RNG_STATE_DTYPE).tofile(py_rng_path) - seed_info = { "np_rng_kind": snapshot.np_rng[0], "np_pos": snapshot.np_rng[2], @@ -206,26 +176,18 @@ def write(cls, snapshot: SeedSnapshot, directory: Path) -> None: "np_cached_gauss": snapshot.np_rng[4], "py_rng_version": py_rng_version, "py_guass_next": py_guass_next, + "py_rng_state": np.array(py_rng_state, dtype=cls.PY_RNG_STATE_DTYPE), + "np_rng_state": snapshot.np_rng[1], } - with atomic_write(seedinfo_path, "w") as f: - json.dump(seed_info, f) - - np_rng_state = snapshot.np_rng[1] - with atomic_write(np_rng_path, "wb") as f: - np_rng_state.tofile(f) - if snapshot.torch_rng is not None: - import torch - - # OPTIM: This ends up being much faster to go to numpy - with atomic_write(torch_rng_path, "wb") as f: - snapshot.torch_rng.numpy().tofile(f) + seed_info["torch_rng"] = snapshot.torch_rng.numpy() if snapshot.torch_cuda_rng 
is not None: - import torch + seed_info["torch_cuda_rng"] = snapshot.torch_cuda_rng - with atomic_write(torch_cuda_rng_path, "wb") as f: - torch.save(snapshot.torch_cuda_rng, f) + seedinfo_path = directory / cls.SEED_FILENAME + with seedinfo_path.open("wb") as f: + pickle.dump(seed_info, f, protocol=pickle.HIGHEST_PROTOCOL) @dataclass @@ -269,7 +231,7 @@ def read(cls, directory: Path) -> OptimizationState: @classmethod def write(cls, info: OptimizationState, directory: Path) -> None: info_path = directory / cls.STATE_FILE_NAME - with atomic_write(info_path, "w") as f: + with info_path.open("w") as f: json.dump(asdict(info), f) @@ -288,7 +250,7 @@ def read(cls, directory: Path) -> ErrDump: @classmethod def write(cls, err_dump: ErrDump, directory: Path) -> None: errors_path = directory / "errors.jsonl" - with atomic_write(errors_path, "w") as f: + with errors_path.open("w") as f: lines = [json.dumps(asdict(trial_err)) for trial_err in err_dump.errs] f.write("\n".join(lines)) @@ -329,7 +291,7 @@ def lock( timeout=self.timeout, flags=FILELOCK_EXCLUSIVE_NONE_BLOCKING, fail_when_locked=fail_if_locked, - ) as fh: + ): if worker_id is not None: logger.debug( "Worker %s acquired lock on %s at %s", @@ -339,8 +301,6 @@ def lock( ) yield - fh.flush() - os.fsync(fh) except pl.exceptions.LockException as e: raise pl.exceptions.LockException( f"Failed to acquire lock after timeout of {self.timeout} seconds." 
diff --git a/neps/state/neps_state.py b/neps/state/neps_state.py index e46d25d9..75174c22 100644 --- a/neps/state/neps_state.py +++ b/neps/state/neps_state.py @@ -10,6 +10,8 @@ from __future__ import annotations +import gc +import io import logging import pickle import time @@ -29,6 +31,7 @@ from neps.env import ( STATE_FILELOCK_POLL, STATE_FILELOCK_TIMEOUT, + TRIAL_CACHE_MAX_UPDATES_BEFORE_CONSOLIDATION, TRIAL_FILELOCK_POLL, TRIAL_FILELOCK_TIMEOUT, ) @@ -46,11 +49,14 @@ from neps.state.optimizer import OptimizationState, OptimizerInfo from neps.state.seed_snapshot import SeedSnapshot from neps.state.trial import Report, Trial +from neps.utils.files import atomic_write if TYPE_CHECKING: from neps.optimizers.base_optimizer import BaseOptimizer logger = logging.getLogger(__name__) + + N_UNSAFE_RETRIES = 10 # TODO: Technically we don't need the same Location type for all shared objects. @@ -76,6 +82,7 @@ def make_sha() -> Version: @dataclass class TrialRepo: CACHE_FILE_NAME = ".trial_cache.pkl" + UPDATE_CONSOLIDATION_LIMIT = TRIAL_CACHE_MAX_UPDATES_BEFORE_CONSOLIDATION directory: Path cache_path: Path = field(init=False) @@ -91,31 +98,70 @@ def list_trial_ids(self) -> list[str]: if config_path.name.startswith("config_") and config_path.is_dir() ] + def _read_pkl_and_maybe_consolidate( + self, + *, + consolidate: bool | None = None, + ) -> dict[str, Trial]: + with self.cache_path.open("rb") as f: + _bytes = f.read() + + buffer = io.BytesIO(_bytes) + try: + gc.disable() + trials = {} + updates = [] + while True: + try: + datum = pickle.load(buffer) + if isinstance(datum, dict): + assert len(trials) == 0, "Multiple caches present." + trials = datum + else: + assert isinstance(datum, Trial), "Not a trial." 
+ updates.append(datum) + except EOFError: + break + + trials.update({trial.id: trial for trial in updates}) + if consolidate is True or ( + len(updates) > self.UPDATE_CONSOLIDATION_LIMIT and consolidate is None + ): + logger.debug( + "Consolidating trial cache with %d trials and %d updates.", + len(trials), + len(updates), + ) + with atomic_write(self.cache_path, "wb") as f: + pickle.dump(trials, f, protocol=pickle.HIGHEST_PROTOCOL) + + return trials + finally: + gc.enable() + def latest(self) -> dict[str, Trial]: if not self.cache_path.exists(): - # If we end up with no cache but there are trials on disk, we need to - # read them in. However we will not save back the cache here in fear of - # overwriting + # If we end up with no cache but there are trials on disk, we need to read in. if any(path.name.startswith("config_") for path in self.directory.iterdir()): trial_ids = self.list_trial_ids() - return { + trials = { trial_id: self.load_trial_from_disk(trial_id) for trial_id in trial_ids } + with atomic_write(self.cache_path, "wb") as f: + pickle.dump(trials, f, protocol=pickle.HIGHEST_PROTOCOL) return {} - with self.cache_path.open("rb") as f: - return pickle.load(f) # noqa: S301 + return self._read_pkl_and_maybe_consolidate() def new_trial(self, trial: Trial) -> None: config_path = self.directory / f"config_{trial.id}" if config_path.exists(): raise TrialAlreadyExistsError(trial.id, config_path) - trials = self.latest() - with self.cache_path.open("wb") as f: - trials[trial.id] = trial - pickle.dump(trials, f) + + with atomic_write(self.cache_path, "ab") as f: + pickle.dump(trial, f, protocol=pickle.HIGHEST_PROTOCOL) config_path.mkdir(parents=True, exist_ok=True) TrialReaderWriter.write(trial, self.directory / f"config_{trial.id}", hints=None) @@ -126,10 +172,8 @@ def update_trial( *, hints: list[TrialWriteHint] | TrialWriteHint | None = None, ) -> None: - trials = self.latest() - with self.cache_path.open("wb") as f: - trials[trial.id] = trial - 
pickle.dump(trials, f) + with atomic_write(self.cache_path, "ab") as f: + pickle.dump(trial, f, protocol=pickle.HIGHEST_PROTOCOL) TrialReaderWriter.write(trial, self.directory / f"config_{trial.id}", hints=hints) diff --git a/neps/utils/files.py b/neps/utils/files.py index 2ba9894e..d28faf44 100644 --- a/neps/utils/files.py +++ b/neps/utils/files.py @@ -3,12 +3,13 @@ from __future__ import annotations import dataclasses +import gc import os from collections.abc import Iterable, Iterator, Mapping from contextlib import contextmanager from enum import Enum from pathlib import Path -from typing import IO, Any +from typing import IO, Any, Literal import yaml @@ -60,24 +61,55 @@ def serializable_format(data: Any) -> Any: # noqa: PLR0911 return data -def serialize(data: Any, path: Path | str, *, sort_keys: bool = True) -> None: +def serialize( + data: Any, + path: Path | str, + *, + check_serialized: bool = True, + file_format: Literal["json", "yaml"] = "yaml", + sort_keys: bool = True, +) -> None: """Serialize data to a yaml file.""" - data = serializable_format(data) - path = Path(path) - with atomic_write(path, "w") as file_stream: - try: - return yaml.dump(data, file_stream, SafeDumper, sort_keys=sort_keys) - except yaml.representer.RepresenterError as e: - raise TypeError( - "Could not serialize to yaml! The object " - f"{e.args[1]} of type {type(e.args[1])} is not." - ) from e - + if check_serialized: + data = serializable_format(data) -def deserialize(path: Path | str) -> dict[str, Any]: + path = Path(path) + try: + gc.disable() + with path.open("w") as file_stream: + if file_format == "yaml": + try: + return yaml.dump(data, file_stream, SafeDumper, sort_keys=sort_keys) + except yaml.representer.RepresenterError as e: + raise TypeError( + "Could not serialize to yaml! The object " + f"{e.args[1]} of type {type(e.args[1])} is not." 
+ ) from e + elif file_format == "json": + import json + + return json.dump(data, file_stream, sort_keys=sort_keys) + else: + raise ValueError(f"Unknown format: {file_format}") + finally: + gc.enable() + + +def deserialize( + path: Path | str, + *, + file_format: Literal["json", "yaml"] = "yaml", +) -> dict[str, Any]: """Deserialize data from a yaml file.""" with Path(path).open("r") as file_stream: - data = yaml.load(file_stream, SafeLoader) + if file_format == "json": + import json + + data = json.load(file_stream) + elif file_format == "yaml": + data = yaml.load(file_stream, SafeLoader) + else: + raise ValueError(f"Unknown format: {file_format}") if not isinstance(data, dict): raise TypeError( From c6a5003abf0a9d6cbe10cb07f6a6217ded9dcf5e Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Tue, 10 Dec 2024 18:02:11 +0100 Subject: [PATCH 51/56] optim: Bunch of micro-optimizations --- neps/env.py | 2 - neps/runtime.py | 20 +- neps/state/filebased.py | 261 ++++---------- neps/state/neps_state.py | 330 ++++++------------ neps/state/optimizer.py | 10 +- neps/state/trial.py | 19 +- neps/status/status.py | 4 +- neps/utils/cli.py | 18 +- neps/utils/common.py | 28 +- neps/utils/files.py | 49 ++- .../test_default_report_values.py | 20 +- .../test_error_handling_strategies.py | 19 +- tests/test_runtime/test_stopping_criterion.py | 17 +- tests/test_state/test_filebased_neps_state.py | 8 +- tests/test_state/test_neps_state.py | 7 +- tests/test_state/test_trial.py | 18 +- 16 files changed, 336 insertions(+), 494 deletions(-) diff --git a/neps/env.py b/neps/env.py index 01a5883b..637452c8 100644 --- a/neps/env.py +++ b/neps/env.py @@ -40,7 +40,6 @@ def yaml_or_json(e: str) -> Literal["yaml", "json"]: parse=str, default="lockf", ) - MAX_RETRIES_GET_NEXT_TRIAL = get_env( "NEPS_MAX_RETRIES_GET_NEXT_TRIAL", parse=int, @@ -96,7 +95,6 @@ def yaml_or_json(e: str) -> Literal["yaml", "json"]: parse=lambda e: None if is_nullable(e) else float(e), default=120, ) - GLOBAL_ERR_FILELOCK_POLL 
= get_env( "NEPS_GLOBAL_ERR_FILELOCK_POLL", parse=float, diff --git a/neps/runtime.py b/neps/runtime.py index 7f5b1ae5..47b08312 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -37,8 +37,10 @@ from neps.state._eval import evaluate_trial from neps.state.neps_state import NePSState from neps.state.optimizer import BudgetInfo, OptimizationState, OptimizerInfo +from neps.state.seed_snapshot import SeedSnapshot from neps.state.settings import DefaultReportValues, OnErrorPossibilities, WorkerSettings from neps.state.trial import Trial +from neps.utils.common import gc_disabled if TYPE_CHECKING: from neps.optimizers.base_optimizer import BaseOptimizer @@ -358,7 +360,9 @@ def _get_next_trial(self) -> Trial | Literal["break"]: # NOTE: It's important to release the trial lock before sampling # as otherwise, any other service, such as reporting the result # of a trial. Hence we do not lock these together with the above. - with self.state._trial_lock.lock(worker_id=self.worker_id): + # OPTIM: We try to prevent garbage collection from happening in here to + # minimize time spent holding on to the lock. + with self.state._trial_lock.lock(worker_id=self.worker_id), gc_disabled(): # Give the file-system some time to sync if we encountered out-of-order # issues with this worker. 
if self._GRACE > 0: @@ -375,7 +379,7 @@ def _get_next_trial(self) -> Trial | Literal["break"]: pending_trials = [ trial for trial in trials.values() - if trial.state == Trial.State.PENDING + if trial.metadata.state == Trial.State.PENDING ] if len(pending_trials) > 0: @@ -387,10 +391,7 @@ def _get_next_trial(self) -> Trial | Literal["break"]: time_started=time.time(), worker_id=self.worker_id, ) - self.state._trials.update_trial( - earliest_pending, - hints=["metadata", "state"], - ) + self.state._trials.update_trial(earliest_pending, hints="metadata") logger.info( "Worker '%s' picked up pending trial: %s.", self.worker_id, @@ -404,7 +405,7 @@ def _get_next_trial(self) -> Trial | Literal["break"]: trials=trials, ) - with self.state._trial_lock.lock(worker_id=self.worker_id): + with self.state._trial_lock.lock(worker_id=self.worker_id), gc_disabled(): try: sampled_trial.set_evaluating( time_started=time.time(), @@ -546,7 +547,7 @@ def run(self) -> None: # noqa: C901, PLR0915 "Worker '%s' evaluated trial: %s as %s.", self.worker_id, evaluated_trial.id, - evaluated_trial.state, + evaluated_trial.metadata.state, ) if report.cost is not None: @@ -610,6 +611,7 @@ def _launch_runtime( # noqa: PLR0913 load_only=False, optimizer_info=OptimizerInfo(optimizer_info), optimizer_state=OptimizationState( + seed_snapshot=SeedSnapshot.new_capture(), budget=( BudgetInfo( max_cost_budget=max_cost_total, @@ -618,7 +620,7 @@ def _launch_runtime( # noqa: PLR0913 used_evaluations=0, ) ), - shared_state={}, # TODO: Unused for the time being... + shared_state=None, # TODO: Unused for the time being... 
), ) break diff --git a/neps/state/filebased.py b/neps/state/filebased.py index 681c4d60..c43f533d 100644 --- a/neps/state/filebased.py +++ b/neps/state/filebased.py @@ -3,30 +3,27 @@ import contextlib import json import logging -import pickle import pprint import time -from collections.abc import Iterator +from collections.abc import Iterable, Iterator from contextlib import contextmanager from dataclasses import asdict, dataclass from pathlib import Path -from typing import ClassVar, Final, Literal, TypeAlias, TypeVar +from typing import Literal, TypeAlias, TypeVar -import numpy as np import portalocker as pl from neps.env import CONFIG_SERIALIZE_FORMAT, ENV_VARS_USED from neps.state.err_dump import ErrDump -from neps.state.optimizer import BudgetInfo, OptimizationState, OptimizerInfo -from neps.state.seed_snapshot import SeedSnapshot from neps.state.trial import Trial +from neps.utils.common import gc_disabled from neps.utils.files import deserialize, serialize logger = logging.getLogger(__name__) K = TypeVar("K") T = TypeVar("T") -TrialWriteHint: TypeAlias = Literal["metadata", "report", "state", "config"] +TrialWriteHint: TypeAlias = Literal["metadata", "report", "config"] @dataclass @@ -42,23 +39,22 @@ class ReaderWriterTrial: # is much faster. 
METADATA_FILENAME = "metadata.json" - STATE_FILENAME = "state.txt" PREVIOUS_TRIAL_ID_FILENAME = "previous_trial_id.txt" @classmethod def read(cls, directory: Path) -> Trial: config_path = directory / cls.CONFIG_FILENAME metadata_path = directory / cls.METADATA_FILENAME - state_path = directory / cls.STATE_FILENAME report_path = directory / cls.REPORT_FILENAME with metadata_path.open("r") as f: metadata = json.load(f) + metadata["state"] = Trial.State(metadata["state"]) + return Trial( config=deserialize(config_path, file_format=CONFIG_SERIALIZE_FORMAT), metadata=Trial.MetaData(**metadata), - state=Trial.State(state_path.read_text(encoding="utf-8").strip()), report=( Trial.Report( **deserialize(report_path, file_format=CONFIG_SERIALIZE_FORMAT), @@ -74,165 +70,66 @@ def write( trial: Trial, directory: Path, *, - hints: list[TrialWriteHint] | TrialWriteHint | None = None, + hints: Iterable[TrialWriteHint] | TrialWriteHint | None = None, + _recurse: bool = False, ) -> None: config_path = directory / cls.CONFIG_FILENAME metadata_path = directory / cls.METADATA_FILENAME - state_path = directory / cls.STATE_FILENAME - - if isinstance(hints, str): - match hints: - case "config": - serialize( - trial.config, - config_path, - check_serialized=False, - file_format=CONFIG_SERIALIZE_FORMAT, - ) - case "metadata": - data = asdict(trial.metadata) - with metadata_path.open("w") as f: - json.dump(data, f) - - if trial.metadata.previous_trial_id is not None: - previous_trial_path = directory / cls.PREVIOUS_TRIAL_ID_FILENAME - previous_trial_path.write_text(trial.metadata.previous_trial_id) - case "report": - if trial.report is None: - raise ValueError( - "Cannot write report 'hint' when report is None." 
- ) - - report_path = directory / cls.REPORT_FILENAME - _report = asdict(trial.report) - if (err := _report.get("err")) is not None: - _report["err"] = str(err) - - serialize( - _report, - report_path, - check_serialized=False, - file_format=CONFIG_SERIALIZE_FORMAT, - ) - case "state": - state_path.write_text(trial.state.value, encoding="utf-8") - case _: - raise ValueError(f"Invalid hint: {hints}") - elif isinstance(hints, list): - for hint in hints: - cls.write(trial, directory, hints=hint) - elif hints is None: - # We don't know, write everything - cls.write(trial, directory, hints=["config", "metadata", "state"]) - - if trial.report is not None: - cls.write(trial, directory, hints="report") - else: - raise ValueError(f"Invalid hint: {hints}") - - -TrialReaderWriter: Final = ReaderWriterTrial() - - -@dataclass -class ReaderWriterSeedSnapshot: - """ReaderWriter for SeedSnapshot objects.""" - - # It seems like they're all uint32 but I can't be sure. - PY_RNG_STATE_DTYPE: ClassVar = np.int64 - SEED_FILENAME: ClassVar = "seed.pickle" - - @classmethod - def read(cls, directory: Path) -> SeedSnapshot: - seedinfo_path = directory / cls.SEED_FILENAME - - with seedinfo_path.open("rb") as f: - seed_info = pickle.load(f) - - return SeedSnapshot( - np_rng=( - seed_info["np_rng_kind"], - seed_info["np_rng_state"], - seed_info["np_pos"], - seed_info["np_has_gauss"], - seed_info["np_cached_gauss"], - ), - py_rng=( - seed_info["py_rng_version"], - seed_info["py_rng_state"], - seed_info["py_guass_next"], - ), - torch_rng=seed_info.get("torch_rng"), - torch_cuda_rng=seed_info.get("torch_cuda_rng"), - ) - - @classmethod - def write(cls, snapshot: SeedSnapshot, directory: Path) -> None: - py_rng_version, py_rng_state, py_guass_next = snapshot.py_rng - seed_info = { - "np_rng_kind": snapshot.np_rng[0], - "np_pos": snapshot.np_rng[2], - "np_has_gauss": snapshot.np_rng[3], - "np_cached_gauss": snapshot.np_rng[4], - "py_rng_version": py_rng_version, - "py_guass_next": py_guass_next, - 
"py_rng_state": np.array(py_rng_state, dtype=cls.PY_RNG_STATE_DTYPE), - "np_rng_state": snapshot.np_rng[1], - } - if snapshot.torch_rng is not None: - seed_info["torch_rng"] = snapshot.torch_rng.numpy() - - if snapshot.torch_cuda_rng is not None: - seed_info["torch_cuda_rng"] = snapshot.torch_cuda_rng - - seedinfo_path = directory / cls.SEED_FILENAME - with seedinfo_path.open("wb") as f: - pickle.dump(seed_info, f, protocol=pickle.HIGHEST_PROTOCOL) - -@dataclass -class ReaderWriterOptimizerInfo: - """ReaderWriter for OptimizerInfo objects.""" - - INFO_FILENAME: ClassVar = "info.yaml" - - @classmethod - def read(cls, directory: Path) -> OptimizerInfo: - info_path = directory / cls.INFO_FILENAME - return OptimizerInfo(info=deserialize(info_path)) - - @classmethod - def write(cls, optimizer_info: OptimizerInfo, directory: Path) -> None: - info_path = directory / cls.INFO_FILENAME - serialize(optimizer_info.info, info_path) - - -# TODO(eddiebergman): If an optimizer wants to store some hefty state, i.e. a numpy array -# or something, this is horribly inefficient and we would need to adapt OptimizerState to -# handle this. 
-# TODO(eddiebergman): May also want to consider serializing budget into a seperate entity -@dataclass -class ReaderWriterOptimizationState: - """ReaderWriter for OptimizationState objects.""" - - STATE_FILE_NAME: ClassVar = "state.json" - - @classmethod - def read(cls, directory: Path) -> OptimizationState: - state_path = directory / cls.STATE_FILE_NAME - with state_path.open("r") as f: - state = json.load(f) - - shared_state = state.get("shared_state") or {} - budget_info = state.get("budget") - budget = BudgetInfo(**budget_info) if budget_info is not None else None - return OptimizationState(shared_state=shared_state, budget=budget) + cm = contextlib.nullcontext if _recurse else gc_disabled + with cm(): + if isinstance(hints, str): + match hints: + case "config": + serialize( + trial.config, + config_path, + check_serialized=False, + file_format=CONFIG_SERIALIZE_FORMAT, + ) + case "metadata": + data = asdict(trial.metadata) + data["state"] = data["state"].value + with metadata_path.open("w") as f: + json.dump(data, f) + + if trial.metadata.previous_trial_id is not None: + previous_trial_path = ( + directory / cls.PREVIOUS_TRIAL_ID_FILENAME + ) + previous_trial_path.write_text( + trial.metadata.previous_trial_id + ) + case "report": + if trial.report is None: + raise ValueError( + "Cannot write report 'hint' when report is None." 
+ ) + + report_path = directory / cls.REPORT_FILENAME + _report = asdict(trial.report) + if (err := _report.get("err")) is not None: + _report["err"] = str(err) + + serialize( + _report, + report_path, + check_serialized=False, + file_format=CONFIG_SERIALIZE_FORMAT, + ) + case _: + raise ValueError(f"Invalid hint: {hints}") + elif isinstance(hints, Iterable): + for hint in hints: + cls.write(trial, directory, hints=hint, _recurse=True) + elif hints is None: + # We don't know, write everything + cls.write(trial, directory, hints=["config", "metadata"], _recurse=True) - @classmethod - def write(cls, info: OptimizationState, directory: Path) -> None: - info_path = directory / cls.STATE_FILE_NAME - with info_path.open("w") as f: - json.dump(asdict(info), f) + if trial.report is not None: + cls.write(trial, directory, hints="report", _recurse=True) + else: + raise ValueError(f"Invalid hint: {hints}") @dataclass @@ -240,17 +137,18 @@ class ReaderWriterErrDump: """ReaderWriter for shared error lists.""" @classmethod - def read(cls, directory: Path) -> ErrDump: - errors_path = directory / "errors.jsonl" - with errors_path.open("r") as f: + def read(cls, path: Path) -> ErrDump: + if not path.exists(): + return ErrDump([]) + + with path.open("r") as f: data = [json.loads(line) for line in f] return ErrDump([ErrDump.SerializableTrialError(**d) for d in data]) @classmethod - def write(cls, err_dump: ErrDump, directory: Path) -> None: - errors_path = directory / "errors.jsonl" - with errors_path.open("w") as f: + def write(cls, err_dump: ErrDump, path: Path) -> None: + with path.open("w") as f: lines = [json.dumps(asdict(trial_err)) for trial_err in err_dump.errs] f.write("\n".join(lines)) @@ -274,24 +172,17 @@ class FileLocker: def __post_init__(self) -> None: self.lock_path = self.lock_path.resolve().absolute() + self._lock = pl.Lock( + self.lock_path, + check_interval=self.poll, + timeout=self.timeout, + flags=FILELOCK_EXCLUSIVE_NONE_BLOCKING, + ) @contextmanager - def lock( 
- self, - *, - fail_if_locked: bool = False, - worker_id: str | None = None, - ) -> Iterator[None]: - self.lock_path.parent.mkdir(parents=True, exist_ok=True) - + def lock(self, *, worker_id: str | None = None) -> Iterator[None]: try: - with pl.Lock( - self.lock_path, - check_interval=self.poll, - timeout=self.timeout, - flags=FILELOCK_EXCLUSIVE_NONE_BLOCKING, - fail_when_locked=fail_if_locked, - ): + with self._lock: if worker_id is not None: logger.debug( "Worker %s acquired lock on %s at %s", diff --git a/neps/state/neps_state.py b/neps/state/neps_state.py index 75174c22..e4cbd4b3 100644 --- a/neps/state/neps_state.py +++ b/neps/state/neps_state.py @@ -10,17 +10,15 @@ from __future__ import annotations -import gc import io import logging import pickle import time -from collections.abc import Callable +from collections.abc import Callable, Iterable from dataclasses import dataclass, field from pathlib import Path from typing import ( TYPE_CHECKING, - Generic, Literal, TypeAlias, TypeVar, @@ -29,6 +27,8 @@ from uuid import uuid4 from neps.env import ( + GLOBAL_ERR_FILELOCK_POLL, + GLOBAL_ERR_FILELOCK_TIMEOUT, STATE_FILELOCK_POLL, STATE_FILELOCK_TIMEOUT, TRIAL_CACHE_MAX_UPDATES_BEFORE_CONSOLIDATION, @@ -40,19 +40,16 @@ from neps.state.filebased import ( FileLocker, ReaderWriterErrDump, - ReaderWriterOptimizationState, - ReaderWriterOptimizerInfo, - ReaderWriterSeedSnapshot, - TrialReaderWriter, + ReaderWriterTrial, TrialWriteHint, ) -from neps.state.optimizer import OptimizationState, OptimizerInfo -from neps.state.seed_snapshot import SeedSnapshot +from neps.state.optimizer import OptimizerInfo from neps.state.trial import Report, Trial -from neps.utils.files import atomic_write +from neps.utils.files import atomic_write, deserialize, serialize if TYPE_CHECKING: from neps.optimizers.base_optimizer import BaseOptimizer + from neps.state.optimizer import OptimizationState logger = logging.getLogger(__name__) @@ -107,37 +104,34 @@ def _read_pkl_and_maybe_consolidate( 
_bytes = f.read() buffer = io.BytesIO(_bytes) - try: - gc.disable() - trials = {} - updates = [] - while True: - try: - datum = pickle.load(buffer) - if isinstance(datum, dict): - assert len(trials) == 0, "Multiple caches present." - trials = datum - else: - assert isinstance(datum, Trial), "Not a trial." - updates.append(datum) - except EOFError: - break - - trials.update({trial.id: trial for trial in updates}) - if consolidate is True or ( - len(updates) > self.UPDATE_CONSOLIDATION_LIMIT and consolidate is None - ): - logger.debug( - "Consolidating trial cache with %d trials and %d updates.", - len(trials), - len(updates), - ) - with atomic_write(self.cache_path, "wb") as f: - pickle.dump(trials, f, protocol=pickle.HIGHEST_PROTOCOL) + trials = {} + updates = [] + while True: + try: + datum = pickle.load(buffer) # noqa: S301 + if isinstance(datum, dict): + assert len(trials) == 0, "Multiple caches present." + trials = datum + else: + assert isinstance(datum, Trial), "Not a trial." + updates.append(datum) + except EOFError: + break + + trials.update({trial.id: trial for trial in updates}) + if consolidate is True or ( + len(updates) > self.UPDATE_CONSOLIDATION_LIMIT and consolidate is None + ): + logger.debug( + "Consolidating trial cache with %d trials and %d updates.", + len(trials), + len(updates), + ) + pickle_bytes = pickle.dumps(trials, protocol=pickle.HIGHEST_PROTOCOL) + with atomic_write(self.cache_path, "wb") as f: + f.write(pickle_bytes) - return trials - finally: - gc.enable() + return trials def latest(self) -> dict[str, Trial]: if not self.cache_path.exists(): @@ -148,8 +142,9 @@ def latest(self) -> dict[str, Trial]: trial_id: self.load_trial_from_disk(trial_id) for trial_id in trial_ids } + pickle_bytes = pickle.dumps(trials, protocol=pickle.HIGHEST_PROTOCOL) with atomic_write(self.cache_path, "wb") as f: - pickle.dump(trials, f, protocol=pickle.HIGHEST_PROTOCOL) + f.write(pickle_bytes) return {} @@ -160,103 +155,35 @@ def new_trial(self, trial: 
Trial) -> None: if config_path.exists(): raise TrialAlreadyExistsError(trial.id, config_path) + bytes_ = pickle.dumps(trial, protocol=pickle.HIGHEST_PROTOCOL) with atomic_write(self.cache_path, "ab") as f: - pickle.dump(trial, f, protocol=pickle.HIGHEST_PROTOCOL) + f.write(bytes_) config_path.mkdir(parents=True, exist_ok=True) - TrialReaderWriter.write(trial, self.directory / f"config_{trial.id}", hints=None) + ReaderWriterTrial.write( + trial, + self.directory / f"config_{trial.id}", + hints=["config", "metadata"], + ) def update_trial( self, trial: Trial, *, - hints: list[TrialWriteHint] | TrialWriteHint | None = None, + hints: Iterable[TrialWriteHint] | TrialWriteHint | None = ("report", "metadata"), ) -> None: + bytes_ = pickle.dumps(trial, protocol=pickle.HIGHEST_PROTOCOL) with atomic_write(self.cache_path, "ab") as f: - pickle.dump(trial, f, protocol=pickle.HIGHEST_PROTOCOL) + f.write(bytes_) - TrialReaderWriter.write(trial, self.directory / f"config_{trial.id}", hints=hints) + ReaderWriterTrial.write(trial, self.directory / f"config_{trial.id}", hints=hints) def load_trial_from_disk(self, trial_id: str) -> Trial: config_path = self.directory / f"config_{trial_id}" if not config_path.exists(): raise TrialNotFoundError(trial_id, config_path) - return TrialReaderWriter.read(config_path) - - -@dataclass -class VersionedResource(Generic[T]): - resource: T - path: Path - read: Callable[[Path], T] - write: Callable[[T, Path], None] - version_file: Path - version: Version = "__not_yet_written__" - - def latest(self) -> T: - if not self.version_file.exists(): - return self.resource - - file_version = self.version_file.read_text() - if self.version == file_version: - return self.resource - - self.resource = self.read(self.path) - self.version = file_version - return self.resource - - def update(self, new_resource: T) -> Version: - self.resource = new_resource - self.version = make_sha() - self.version_file.write_text(self.version) - self.write(new_resource, self.path) 
- return self.version - - @classmethod - def new( - cls, - resource: T, - path: Path, - read: Callable[[Path], T], - write: Callable[[T, Path], None], - version_file: Path, - ) -> VersionedResource[T]: - if version_file.exists(): - raise FileExistsError(f"Version file already exists at '{version_file}'.") - - write(resource, path) - version = make_sha() - version_file.write_text(version) - return cls( - resource=resource, - path=path, - read=read, - write=write, - version_file=version_file, - version=version, - ) - - @classmethod - def load( - cls, - path: Path, - *, - read: Callable[[Path], T], - write: Callable[[T, Path], None], - version_file: Path, - ) -> VersionedResource[T]: - if not path.exists(): - raise FileNotFoundError(f"Resource not found at '{path}'.") - - return cls( - resource=read(path), - path=path, - read=read, - write=write, - version_file=version_file, - version=version_file.read_text(), - ) + return ReaderWriterTrial.read(config_path) @dataclass @@ -269,12 +196,16 @@ class NePSState: _trials: TrialRepo = field(repr=False) _optimizer_lock: FileLocker = field(repr=False) - _optimizer_info: VersionedResource[OptimizerInfo] = field(repr=False) - _seed_snapshot: VersionedResource[SeedSnapshot] = field(repr=False) - _optimizer_state: VersionedResource[OptimizationState] = field(repr=False) + + _optimizer_info_path: Path = field(repr=False) + _optimizer_info: OptimizerInfo = field(repr=False) + + _optimizer_state_path: Path = field(repr=False) + _optimizer_state: OptimizationState = field(repr=False) _err_lock: FileLocker = field(repr=False) - _shared_errors: VersionedResource[ErrDump] = field(repr=False) + _shared_errors_path: Path = field(repr=False) + _shared_errors: ErrDump = field(repr=False) def lock_and_read_trials(self) -> dict[str, Trial]: """Acquire the state lock and read the trials.""" @@ -328,10 +259,10 @@ def _sample_trial( Returns: The new trial. 
""" - seed_state = self._seed_snapshot.latest() - opt_state = self._optimizer_state.latest() + with self._optimizer_state_path.open("rb") as f: + opt_state = pickle.load(f) # noqa: S301 - seed_state.set_as_global_seed_state() + opt_state.seed_snapshot.set_as_global_seed_state() # TODO: Not sure if any existing pre_load hooks required # it to be done after `load_results`... I hope not. @@ -339,32 +270,27 @@ def _sample_trial( for hook in _sample_hooks: optimizer = hook(optimizer) - # NOTE: Re-work this, as the part's that are recomputed - # do not need to be serialized - budget = opt_state.budget - if budget is not None: - budget = budget.clone() - + if opt_state.budget is not None: # NOTE: All other values of budget are ones that should remain # constant, there are currently only these two which are dynamic as # optimization unfold - budget.used_cost_budget = sum( + opt_state.budget.used_cost_budget = sum( trial.report.cost for trial in trials.values() if trial.report is not None and trial.report.cost is not None ) - budget.used_evaluations = len(trials) + opt_state.budget.used_evaluations = len(trials) sampled_config_maybe_new_opt_state = optimizer.ask( trials=trials, - budget_info=budget, + budget_info=opt_state.budget.clone(), ) if isinstance(sampled_config_maybe_new_opt_state, tuple): - sampled_config, new_opt_state = sampled_config_maybe_new_opt_state + sampled_config, shared_state = sampled_config_maybe_new_opt_state else: sampled_config = sampled_config_maybe_new_opt_state - new_opt_state = opt_state.shared_state + shared_state = opt_state.shared_state if sampled_config.previous_config_id is not None: previous_trial = trials.get(sampled_config.previous_config_id) @@ -385,11 +311,11 @@ def _sample_trial( time_sampled=time.time(), worker_id=worker_id, ) - seed_state.recapture() - self._seed_snapshot.update(seed_state) - self._optimizer_state.update( - OptimizationState(budget=opt_state.budget, shared_state=new_opt_state) - ) + + opt_state.shared_state = 
shared_state + opt_state.seed_snapshot.recapture() + with self._optimizer_state_path.open("wb") as f: + pickle.dump(opt_state, f, protocol=pickle.HIGHEST_PROTOCOL) return trial @@ -411,11 +337,11 @@ def _report_trial_evaluation( """ # IMPORTANT: We need to attach the report to the trial before updating the things. trial.report = report - self._trials.update_trial(trial, hints=["report", "metadata", "state"]) + self._trials.update_trial(trial, hints=["report", "metadata"]) if report.err is not None: with self._err_lock.lock(): - err_dump = self._shared_errors.latest() + err_dump = ReaderWriterErrDump.read(self._shared_errors_path) err_dump.errs.append( ErrDump.SerializableTrialError( trial_id=trial.id, @@ -425,7 +351,7 @@ def _report_trial_evaluation( tb=report.tb, ) ) - self._shared_errors.update(err_dump) + ReaderWriterErrDump.write(err_dump, self._shared_errors_path) def all_trial_ids(self) -> list[str]: """Get all the trial ids.""" @@ -434,17 +360,18 @@ def all_trial_ids(self) -> list[str]: def lock_and_get_errors(self) -> ErrDump: """Get all the errors that have occurred during the optimization.""" with self._err_lock.lock(): - return self._shared_errors.latest() + return ReaderWriterErrDump.read(self._shared_errors_path) def lock_and_get_optimizer_info(self) -> OptimizerInfo: """Get the optimizer information.""" with self._optimizer_lock.lock(): - return self._optimizer_info.latest() + return OptimizerInfo(info=deserialize(self._optimizer_info_path)) def lock_and_get_optimizer_state(self) -> OptimizationState: """Get the optimizer state.""" - with self._optimizer_lock.lock(): - return self._optimizer_state.latest() + with self._optimizer_lock.lock(): # noqa: SIM117 + with self._optimizer_state_path.open("rb") as f: + return pickle.load(f) # noqa: S301 def lock_and_get_trial_by_id(self, trial_id: str) -> Trial: """Get a trial by its id.""" @@ -503,7 +430,7 @@ def lock_and_get_next_pending_trial( [ trial for trial in trials.values() - if trial.state == 
Trial.State.PENDING + if trial.metadata.state == Trial.State.PENDING ], key=lambda t: t.metadata.time_sampled, ) @@ -557,14 +484,10 @@ def create_or_load( path.mkdir(parents=True, exist_ok=True) config_dir = path / "configs" config_dir.mkdir(parents=True, exist_ok=True) - seed_dir = path / ".seed_state" - seed_dir.mkdir(parents=True, exist_ok=True) - error_dir = path / ".errors" - error_dir.mkdir(parents=True, exist_ok=True) - optimizer_state_dir = path / ".optimizer_state" - optimizer_state_dir.mkdir(parents=True, exist_ok=True) - optimizer_info_dir = path / ".optimizer_info" - optimizer_info_dir.mkdir(parents=True, exist_ok=True) + + optimizer_info_path = path / "optimizer_info.yaml" + optimizer_state_path = path / "optimizer_state.pkl" + shared_errors_path = path / "shared_errors.jsonl" # We have to do one bit of sanity checking to ensure that the optimzier # info on disk manages the one we have recieved, otherwise we are unsure which @@ -575,70 +498,29 @@ def create_or_load( # check the optimizer info. 
If this assumption changes, then we would have # to first lock before we do this check if not is_new: - _optimizer_info = VersionedResource.load( - optimizer_info_dir, - read=ReaderWriterOptimizerInfo.read, - write=ReaderWriterOptimizerInfo.write, - version_file=optimizer_info_dir / ".version", - ) - _optimizer_state = VersionedResource.load( - optimizer_state_dir, - read=ReaderWriterOptimizationState.read, - write=ReaderWriterOptimizationState.write, - version_file=optimizer_state_dir / ".version", - ) - _seed_snapshot = VersionedResource.load( - seed_dir, - read=ReaderWriterSeedSnapshot.read, - write=ReaderWriterSeedSnapshot.write, - version_file=seed_dir / ".version", - ) - _shared_errors = VersionedResource.load( - error_dir, - read=ReaderWriterErrDump.read, - write=ReaderWriterErrDump.write, - version_file=error_dir / ".version", - ) - existing_info = _optimizer_info.latest() + existing_info = OptimizerInfo(info=deserialize(optimizer_info_path)) if not load_only and existing_info != optimizer_info: raise NePSError( "The optimizer info on disk does not match the one provided." f"\nOn disk: {existing_info}\nProvided: {optimizer_info}" - f"\n\nLoaded the one on disk from {optimizer_info_dir}." + f"\n\nLoaded the one on disk from {path}." 
) + with optimizer_state_path.open("rb") as f: + optimizer_state = pickle.load(f) # noqa: S301 + + optimizer_info = existing_info + error_dump = ReaderWriterErrDump.read(shared_errors_path) else: assert optimizer_info is not None assert optimizer_state is not None - _optimizer_info = VersionedResource.new( - resource=optimizer_info, - path=optimizer_info_dir, - read=ReaderWriterOptimizerInfo.read, - write=ReaderWriterOptimizerInfo.write, - version_file=optimizer_info_dir / ".version", - ) - _optimizer_state = VersionedResource.new( - resource=optimizer_state, - path=optimizer_state_dir, - read=ReaderWriterOptimizationState.read, - write=ReaderWriterOptimizationState.write, - version_file=optimizer_state_dir / ".version", - ) - _seed_snapshot = VersionedResource.new( - resource=SeedSnapshot.new_capture(), - path=seed_dir, - read=ReaderWriterSeedSnapshot.read, - write=ReaderWriterSeedSnapshot.write, - version_file=seed_dir / ".version", - ) - _shared_errors = VersionedResource.new( - resource=ErrDump(), - path=error_dir, - read=ReaderWriterErrDump.read, - write=ReaderWriterErrDump.write, - version_file=error_dir / ".version", - ) - return cls( + serialize(optimizer_info.info, path=optimizer_info_path) + with optimizer_state_path.open("wb") as f: + pickle.dump(optimizer_state, f, protocol=pickle.HIGHEST_PROTOCOL) + + error_dump = ErrDump([]) + + return NePSState( path=path, _trials=TrialRepo(config_dir), # Locks, @@ -648,18 +530,20 @@ def create_or_load( timeout=TRIAL_FILELOCK_TIMEOUT, ), _optimizer_lock=FileLocker( - lock_path=path / ".state.lock", + lock_path=path / ".optimizer.lock", poll=STATE_FILELOCK_POLL, timeout=STATE_FILELOCK_TIMEOUT, ), _err_lock=FileLocker( - lock_path=error_dir / "errors.lock", - poll=TRIAL_FILELOCK_POLL, - timeout=TRIAL_FILELOCK_TIMEOUT, + lock_path=path / ".errors.lock", + poll=GLOBAL_ERR_FILELOCK_POLL, + timeout=GLOBAL_ERR_FILELOCK_TIMEOUT, ), # State - _optimizer_info=_optimizer_info, - _optimizer_state=_optimizer_state, - 
_seed_snapshot=_seed_snapshot, - _shared_errors=_shared_errors, + _optimizer_info_path=optimizer_info_path, + _optimizer_info=optimizer_info, + _optimizer_state_path=optimizer_state_path, + _optimizer_state=optimizer_state, # type: ignore + _shared_errors_path=shared_errors_path, + _shared_errors=error_dump, ) diff --git a/neps/state/optimizer.py b/neps/state/optimizer.py index 38d0bfb5..f149c67b 100644 --- a/neps/state/optimizer.py +++ b/neps/state/optimizer.py @@ -4,7 +4,10 @@ from collections.abc import Mapping from dataclasses import dataclass, replace -from typing import Any +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from neps.state.seed_snapshot import SeedSnapshot @dataclass @@ -28,7 +31,10 @@ class OptimizationState: budget: BudgetInfo | None """Information regarind the budget used by the optimization trajectory.""" - shared_state: dict[str, Any] + seed_snapshot: SeedSnapshot + """The state of the random number generators at the time of the last sample.""" + + shared_state: dict[str, Any] | None """Any information the optimizer wants to store between calls to sample and post evaluations. 
diff --git a/neps/state/trial.py b/neps/state/trial.py index f89dd67e..260c728d 100644 --- a/neps/state/trial.py +++ b/neps/state/trial.py @@ -43,8 +43,10 @@ class MetaData: id: str location: str + state: State previous_trial_id: str | None previous_trial_location: str | None + sampling_worker_id: str time_sampled: float @@ -129,7 +131,6 @@ class Trial: config: Mapping[str, Any] metadata: MetaData - state: State report: Report | None @classmethod @@ -147,10 +148,10 @@ def new( """Create a new trial object that was just sampled.""" worker_id = str(worker_id) return cls( - state=State.PENDING, config=config, metadata=MetaData( id=trial_id, + state=State.PENDING, location=location, time_sampled=time_sampled, previous_trial_id=previous_trial, @@ -194,13 +195,13 @@ def into_config_result( def set_submitted(self, *, time_submitted: float) -> None: """Set the trial as submitted.""" self.metadata.time_submitted = time_submitted - self.state = State.SUBMITTED + self.metadata.state = State.SUBMITTED def set_evaluating(self, *, time_started: float, worker_id: int | str) -> None: """Set the trial as in progress.""" self.metadata.time_started = time_started self.metadata.evaluating_worker_id = str(worker_id) - self.state = State.EVALUATING + self.metadata.state = State.EVALUATING def set_complete( self, @@ -217,11 +218,11 @@ def set_complete( ) -> Report: """Set the report for the trial.""" if report_as == "success": - self.state = State.SUCCESS + self.metadata.state = State.SUCCESS elif report_as == "failed": - self.state = State.FAILED + self.metadata.state = State.FAILED elif report_as == "crashed": - self.state = State.CRASHED + self.metadata.state = State.CRASHED else: raise ValueError(f"Invalid report_as: '{report_as}'") @@ -249,13 +250,13 @@ def set_complete( def set_corrupted(self) -> None: """Set the trial as corrupted.""" - self.state = State.CORRUPTED + self.metadata.state = State.CORRUPTED def reset(self) -> None: """Reset the trial to a pending state.""" - 
self.state = State.PENDING self.metadata = MetaData( id=self.metadata.id, + state=State.PENDING, location=self.metadata.location, previous_trial_id=self.metadata.previous_trial_id, previous_trial_location=self.metadata.previous_trial_location, diff --git a/neps/status/status.py b/neps/status/status.py index b738b353..5caaa60c 100644 --- a/neps/status/status.py +++ b/neps/status/status.py @@ -62,12 +62,12 @@ def get_summary_dict( in_progress = { trial.id: trial.config for trial in trials.values() - if trial.State == Trial.State.EVALUATING + if trial.metadata.state == Trial.State.EVALUATING } pending = { trial.id: trial.config for trial in trials.values() - if trial.State == Trial.State.PENDING + if trial.metadata.state == Trial.State.PENDING } summary: dict[str, Any] = {} diff --git a/neps/utils/cli.py b/neps/utils/cli.py index 455fec43..b17b8d26 100644 --- a/neps/utils/cli.py +++ b/neps/utils/cli.py @@ -16,6 +16,7 @@ from typing import Optional, List import neps from neps.api import Default +from neps.state.seed_snapshot import SeedSnapshot from neps.status.status import post_run_csv import pandas as pd from neps.utils.run_args import ( @@ -140,12 +141,13 @@ def init_config(args: argparse.Namespace) -> None: path=directory, optimizer_info=OptimizerInfo(optimizer_info), optimizer_state=OptimizationState( + seed_snapshot=SeedSnapshot.new_capture(), budget=( BudgetInfo(max_cost_budget=max_cost_total, used_cost_budget=0) if max_cost_total is not None else None ), - shared_state={}, # TODO: Unused for the time being... + shared_state=None, # TODO: Unused for the time being... 
), ) if is_new: @@ -338,7 +340,7 @@ def info_config(args: argparse.Namespace) -> None: print("Trial Information:") print(f" Trial ID: {trial.metadata.id}") - print(f" State: {trial.state}") + print(f" State: {trial.metadata.state}") print(f" Configurations:") for key, value in trial.config.items(): print(f" {key}: {value}") @@ -489,7 +491,7 @@ def status(args: argparse.Namespace) -> None: # Calculate the number of trials in different states trials = neps_state.lock_and_read_trials() evaluating_trials_count = sum( - 1 for trial in trials.values() if trial.state == Trial.State.EVALUATING + 1 for trial in trials.values() if trial.metadata.state == Trial.State.EVALUATING ) pending_trials_count = summary["num_pending_configs"] succeeded_trials_count = summary["num_evaluated_configs"] - summary["num_error"] @@ -516,15 +518,15 @@ def status(args: argparse.Namespace) -> None: # Filter trials based on state if args.pending: filtered_trials = [ - trial for trial in sorted_trials if trial.state.name == "PENDING" + trial for trial in sorted_trials if trial.metadata.state.name == "PENDING" ] elif args.evaluating: filtered_trials = [ - trial for trial in sorted_trials if trial.state.name == "EVALUATING" + trial for trial in sorted_trials if trial.metadata.state.name == "EVALUATING" ] elif args.succeeded: filtered_trials = [ - trial for trial in sorted_trials if trial.state.name == "SUCCESS" + trial for trial in sorted_trials if trial.metadata.state.name == "SUCCESS" ] else: filtered_trials = sorted_trials[:7] @@ -543,7 +545,7 @@ def status(args: argparse.Namespace) -> None: # Print the details of the filtered trials for trial in filtered_trials: time_sampled = convert_timestamp(trial.metadata.time_sampled) - if trial.state.name in ["PENDING", "EVALUATING"]: + if trial.metadata.state.name in ["PENDING", "EVALUATING"]: duration = compute_duration(trial.metadata.time_sampled) else: duration = ( @@ -553,7 +555,7 @@ def status(args: argparse.Namespace) -> None: ) trial_id = trial.id 
worker_id = trial.metadata.sampling_worker_id - state = trial.state.name + state = trial.metadata.state.name loss = ( f"{trial.report.loss:.6f}" if (trial.report and trial.report.loss is not None) diff --git a/neps/utils/common.py b/neps/utils/common.py index 3887565e..6fd4a10a 100644 --- a/neps/utils/common.py +++ b/neps/utils/common.py @@ -2,8 +2,10 @@ from __future__ import annotations +import gc import inspect -from collections.abc import Mapping, Sequence +from collections.abc import Iterator, Mapping, Sequence +from contextlib import contextmanager from functools import partial from pathlib import Path from typing import Any @@ -11,8 +13,6 @@ import torch import yaml -from neps.runtime import get_in_progress_trial, get_workers_neps_state - # TODO(eddiebergman): I feel like this function should throw an error if it can't # find anything to load, rather than returning None. In this case, we should provide @@ -35,6 +35,8 @@ def load_checkpoint( A dictionary containing the checkpoint values, or None if the checkpoint file does not exist hence no checkpointing was previously done. """ + from neps.runtime import get_in_progress_trial + if directory is None: trial = get_in_progress_trial() directory = trial.metadata.previous_trial_location @@ -75,6 +77,8 @@ def save_checkpoint( optimizer: The optimizer to save. checkpoint_name: The name of the checkpoint file. """ + from neps.runtime import get_in_progress_trial + if directory is None: in_progress_trial = get_in_progress_trial() directory = in_progress_trial.metadata.location @@ -113,6 +117,8 @@ def load_lightning_checkpoint( A tuple containing the checkpoint path (str) and the loaded checkpoint data (dict) or (None, None) if no checkpoint files are found in the directory. 
""" + from neps.runtime import get_in_progress_trial + if previous_pipeline_directory is None: trial = get_in_progress_trial() previous_pipeline_directory = trial.metadata.previous_trial_location @@ -156,6 +162,8 @@ def get_initial_directory(pipeline_directory: Path | str | None = None) -> Path: Returns: The initial directory. """ + from neps.runtime import get_in_progress_trial, get_workers_neps_state + neps_state = get_workers_neps_state() if pipeline_directory is not None: # TODO: Hard coded assumption @@ -333,3 +341,17 @@ def instance_from_map( # noqa: C901 raise TypeError(f"{e} when calling {instance} with {args_dict}") from e return instance + + +@contextmanager +def gc_disabled() -> Iterator[None]: + """Context manager to disable garbage collection for a block. + + We specifically put this around file I/O operations to minimize the time + spend garbage collecting while having the file handle open. + """ + gc.disable() + try: + yield + finally: + gc.enable() diff --git a/neps/utils/files.py b/neps/utils/files.py index d28faf44..585be708 100644 --- a/neps/utils/files.py +++ b/neps/utils/files.py @@ -3,7 +3,7 @@ from __future__ import annotations import dataclasses -import gc +import io import os from collections.abc import Iterable, Iterator, Mapping from contextlib import contextmanager @@ -15,11 +15,12 @@ try: from yaml import ( - CSafeDumper as SafeDumper, # type: ignore + CDumper as YamlDumper, # type: ignore CSafeLoader as SafeLoader, # type: ignore ) -except ImportError: - from yaml import SafeDumper, SafeLoader # type: ignore +except ImportError as e: + raise ImportError() from e + from yaml import SafeLoader, YamlDumper # type: ignore @contextmanager @@ -63,7 +64,7 @@ def serializable_format(data: Any) -> Any: # noqa: PLR0911 def serialize( data: Any, - path: Path | str, + path: Path, *, check_serialized: bool = True, file_format: Literal["json", "yaml"] = "yaml", @@ -73,26 +74,24 @@ def serialize( if check_serialized: data = serializable_format(data) 
- path = Path(path) - try: - gc.disable() - with path.open("w") as file_stream: - if file_format == "yaml": - try: - return yaml.dump(data, file_stream, SafeDumper, sort_keys=sort_keys) - except yaml.representer.RepresenterError as e: - raise TypeError( - "Could not serialize to yaml! The object " - f"{e.args[1]} of type {type(e.args[1])} is not." - ) from e - elif file_format == "json": - import json - - return json.dump(data, file_stream, sort_keys=sort_keys) - else: - raise ValueError(f"Unknown format: {file_format}") - finally: - gc.enable() + buf = io.StringIO() + if file_format == "yaml": + try: + yaml.dump(data, buf, YamlDumper, sort_keys=sort_keys) + except yaml.representer.RepresenterError as e: + raise TypeError( + "Could not serialize to yaml! The object " + f"{e.args[1]} of type {type(e.args[1])} is not." + ) from e + elif file_format == "json": + import json + + json.dump(data, buf, sort_keys=sort_keys) + else: + raise ValueError(f"Unknown format: {file_format}") + + _str = buf.getvalue() + path.write_text(_str) def deserialize( diff --git a/tests/test_runtime/test_default_report_values.py b/tests/test_runtime/test_default_report_values.py index d857c69a..07c16f34 100644 --- a/tests/test_runtime/test_default_report_values.py +++ b/tests/test_runtime/test_default_report_values.py @@ -6,6 +6,7 @@ from neps.search_spaces.search_space import SearchSpace from neps.state.neps_state import NePSState from neps.state.optimizer import OptimizationState, OptimizerInfo +from neps.state.seed_snapshot import SeedSnapshot from neps.state.settings import DefaultReportValues, OnErrorPossibilities, WorkerSettings from neps.search_spaces import Float from neps.state.trial import Trial @@ -16,7 +17,9 @@ def neps_state(tmp_path: Path) -> NePSState: return NePSState.create_or_load( path=tmp_path / "neps_state", optimizer_info=OptimizerInfo(info={"nothing": "here"}), - optimizer_state=OptimizationState(budget=None, shared_state={}), + optimizer_state=OptimizationState( + 
budget=None, seed_snapshot=SeedSnapshot.new_capture(), shared_state={} + ), ) @@ -55,7 +58,8 @@ def eval_function(*args, **kwargs) -> float: trials = neps_state.lock_and_read_trials() n_crashed = sum( - trial.state == Trial.State.CRASHED is not None for trial in trials.values() + trial.metadata.state == Trial.State.CRASHED is not None + for trial in trials.values() ) assert len(trials) == 1 assert n_crashed == 1 @@ -64,7 +68,7 @@ def eval_function(*args, **kwargs) -> float: assert len(neps_state.lock_and_get_errors()) == 1 trial = trials.popitem()[1] - assert trial.state == Trial.State.CRASHED + assert trial.metadata.state == Trial.State.CRASHED assert trial.report is not None assert trial.report.loss == 2.4 assert trial.report.cost == 2.4 @@ -105,7 +109,8 @@ def eval_function(*args, **kwargs) -> float: trials = neps_state.lock_and_read_trials() n_sucess = sum( - trial.state == Trial.State.SUCCESS is not None for trial in trials.values() + trial.metadata.state == Trial.State.SUCCESS is not None + for trial in trials.values() ) assert len(trials) == 1 assert n_sucess == 1 @@ -114,7 +119,7 @@ def eval_function(*args, **kwargs) -> float: assert len(neps_state.lock_and_get_errors()) == 0 trial = trials.popitem()[1] - assert trial.state == Trial.State.SUCCESS + assert trial.metadata.state == Trial.State.SUCCESS assert trial.report is not None assert trial.report.cost == 2.4 assert trial.report.learning_curve == [2.4, 2.5] @@ -153,7 +158,8 @@ def eval_function(*args, **kwargs) -> float: trials = neps_state.lock_and_read_trials() n_sucess = sum( - trial.state == Trial.State.SUCCESS is not None for trial in trials.values() + trial.metadata.state == Trial.State.SUCCESS is not None + for trial in trials.values() ) assert len(trials) == 1 assert n_sucess == 1 @@ -162,6 +168,6 @@ def eval_function(*args, **kwargs) -> float: assert len(neps_state.lock_and_get_errors()) == 0 trial = trials.popitem()[1] - assert trial.state == Trial.State.SUCCESS + assert trial.metadata.state == 
Trial.State.SUCCESS assert trial.report is not None assert trial.report.learning_curve == [LOSS] diff --git a/tests/test_runtime/test_error_handling_strategies.py b/tests/test_runtime/test_error_handling_strategies.py index d341aa2b..dee3e7eb 100644 --- a/tests/test_runtime/test_error_handling_strategies.py +++ b/tests/test_runtime/test_error_handling_strategies.py @@ -10,6 +10,7 @@ from neps.search_spaces.search_space import SearchSpace from neps.state.neps_state import NePSState from neps.state.optimizer import OptimizationState, OptimizerInfo +from neps.state.seed_snapshot import SeedSnapshot from neps.state.settings import DefaultReportValues, OnErrorPossibilities, WorkerSettings from neps.search_spaces import Float from neps.state.trial import Trial @@ -20,7 +21,11 @@ def neps_state(tmp_path: Path) -> NePSState: return NePSState.create_or_load( path=tmp_path / "neps_state", optimizer_info=OptimizerInfo(info={"nothing": "here"}), - optimizer_state=OptimizationState(budget=None, shared_state={}), + optimizer_state=OptimizationState( + budget=None, + seed_snapshot=SeedSnapshot.new_capture(), + shared_state=None, + ), ) @@ -61,7 +66,8 @@ def eval_function(*args, **kwargs) -> float: trials = neps_state.lock_and_read_trials() n_crashed = sum( - trial.state == Trial.State.CRASHED is not None for trial in trials.values() + trial.metadata.state == Trial.State.CRASHED is not None + for trial in trials.values() ) assert len(trials) == 1 assert n_crashed == 1 @@ -114,7 +120,8 @@ def evaler(*args, **kwargs) -> float: trials = neps_state.lock_and_read_trials() n_crashed = sum( - trial.state == Trial.State.CRASHED is not None for trial in trials.values() + trial.metadata.state == Trial.State.CRASHED is not None + for trial in trials.values() ) assert len(trials) == 1 assert n_crashed == 1 @@ -184,10 +191,12 @@ def __call__(self, *args, **kwargs) -> float: trials = neps_state.lock_and_read_trials() n_success = sum( - trial.state == Trial.State.SUCCESS is not None for trial in 
trials.values() + trial.metadata.state == Trial.State.SUCCESS is not None + for trial in trials.values() ) n_crashed = sum( - trial.state == Trial.State.CRASHED is not None for trial in trials.values() + trial.metadata.state == Trial.State.CRASHED is not None + for trial in trials.values() ) assert n_success == 1 assert n_crashed == 1 diff --git a/tests/test_runtime/test_stopping_criterion.py b/tests/test_runtime/test_stopping_criterion.py index 3e6da7ce..dd9b7bed 100644 --- a/tests/test_runtime/test_stopping_criterion.py +++ b/tests/test_runtime/test_stopping_criterion.py @@ -8,6 +8,7 @@ from neps.state.neps_state import NePSState from neps.state.neps_state import NePSState from neps.state.optimizer import OptimizationState, OptimizerInfo +from neps.state.seed_snapshot import SeedSnapshot from neps.state.settings import DefaultReportValues, OnErrorPossibilities, WorkerSettings from neps.search_spaces import Float from neps.state.trial import Trial @@ -18,7 +19,11 @@ def neps_state(tmp_path: Path) -> NePSState: return NePSState.create_or_load( path=tmp_path / "neps_state", optimizer_info=OptimizerInfo(info={"nothing": "here"}), - optimizer_state=OptimizationState(budget=None, shared_state={}), + optimizer_state=OptimizationState( + budget=None, + seed_snapshot=SeedSnapshot.new_capture(), + shared_state=None, + ), ) @@ -57,7 +62,7 @@ def eval_function(*args, **kwargs) -> float: trials = neps_state.lock_and_read_trials() for _, trial in trials.items(): - assert trial.state == Trial.State.SUCCESS + assert trial.metadata.state == Trial.State.SUCCESS assert trial.report is not None assert trial.report.loss == 1.0 @@ -111,7 +116,7 @@ def eval_function(*args, **kwargs) -> float: trials = neps_state.lock_and_read_trials() assert len(trials) == 2 for _, trial in trials.items(): - assert trial.state == Trial.State.SUCCESS + assert trial.metadata.state == Trial.State.SUCCESS assert trial.report is not None assert trial.report.loss == 1.0 @@ -132,7 +137,7 @@ def 
eval_function(*args, **kwargs) -> float: trials = neps_state.lock_and_read_trials() assert len(trials) == 4 # Now we should have 4 of them for _, trial in trials.items(): - assert trial.state == Trial.State.SUCCESS + assert trial.metadata.state == Trial.State.SUCCESS assert trial.report is not None assert trial.report.loss == 1.0 @@ -182,13 +187,13 @@ def eval_function(*args, **kwargs) -> float: the_pending_trial = trials[pending_trial.id] assert the_pending_trial == pending_trial - assert the_pending_trial.state == Trial.State.EVALUATING + assert the_pending_trial.metadata.state == Trial.State.EVALUATING assert the_pending_trial.report is None the_completed_trial_id = next(iter(trials.keys() - {pending_trial.id})) the_completed_trial = trials[the_completed_trial_id] - assert the_completed_trial.state == Trial.State.SUCCESS + assert the_completed_trial.metadata.state == Trial.State.SUCCESS assert the_completed_trial.report is not None assert the_completed_trial.report.loss == 1.0 diff --git a/tests/test_state/test_filebased_neps_state.py b/tests/test_state/test_filebased_neps_state.py index 87085639..4ad4ee43 100644 --- a/tests/test_state/test_filebased_neps_state.py +++ b/tests/test_state/test_filebased_neps_state.py @@ -11,6 +11,7 @@ import pytest from pytest_cases import fixture, parametrize from neps.state.optimizer import BudgetInfo, OptimizationState, OptimizerInfo +from neps.state.seed_snapshot import SeedSnapshot @fixture @@ -20,7 +21,11 @@ def optimizer_state( budget: BudgetInfo | None, shared_state: dict[str, Any], ) -> OptimizationState: - return OptimizationState(budget=budget, shared_state=shared_state) + return OptimizationState( + budget=budget, + seed_snapshot=SeedSnapshot.new_capture(), + shared_state=shared_state, + ) @fixture @@ -69,6 +74,7 @@ def test_create_or_load_with_load_filebased_neps_state( # was passed in. 
different_state = OptimizationState( budget=BudgetInfo(max_cost_budget=20, used_cost_budget=10), + seed_snapshot=SeedSnapshot.new_capture(), shared_state={"c": "d"}, ) neps_state2 = NePSState.create_or_load( diff --git a/tests/test_state/test_neps_state.py b/tests/test_state/test_neps_state.py index 78b3213b..dd8da3d3 100644 --- a/tests/test_state/test_neps_state.py +++ b/tests/test_state/test_neps_state.py @@ -21,6 +21,7 @@ from neps.state.neps_state import NePSState from neps.state.optimizer import BudgetInfo, OptimizationState, OptimizerInfo from neps.optimizers import SearcherMapping +from neps.state.seed_snapshot import SeedSnapshot @case @@ -157,7 +158,11 @@ def case_neps_state_filebased( return NePSState.create_or_load( path=new_path, optimizer_info=optimizer_info, - optimizer_state=OptimizationState(budget=budget, shared_state=shared_state), + optimizer_state=OptimizationState( + budget=budget, + seed_snapshot=SeedSnapshot.new_capture(), + shared_state=shared_state, + ), ) diff --git a/tests/test_state/test_trial.py b/tests/test_state/test_trial.py index a433a917..c2acf3b9 100644 --- a/tests/test_state/test_trial.py +++ b/tests/test_state/test_trial.py @@ -20,12 +20,13 @@ def test_trial_creation() -> None: previous_trial=previous_trial, worker_id=worker_id, ) - assert trial.state == Trial.State.PENDING + assert trial.metadata.state == Trial.State.PENDING assert trial.id == trial_id assert trial.config == {"a": "b"} assert trial.metadata == Trial.MetaData( id="1", time_sampled=time_sampled, + state=Trial.State.PENDING, location="1", previous_trial_location=None, previous_trial_id=previous_trial, @@ -54,11 +55,12 @@ def test_trial_as_submitted() -> None: ) trial.set_submitted(time_submitted=time_submitted) - assert trial.state == Trial.State.SUBMITTED + assert trial.metadata.state == Trial.State.SUBMITTED assert trial.id == trial_id assert trial.config == {"a": "b"} assert trial.metadata == Trial.MetaData( id=trial_id, + state=Trial.State.SUBMITTED, 
time_sampled=time_sampled, previous_trial_location="0", location="1", @@ -91,11 +93,12 @@ def test_trial_as_in_progress_with_different_evaluating_worker() -> None: trial.set_submitted(time_submitted=time_submitted) trial.set_evaluating(time_started=time_started, worker_id=evaluating_worker_id) - assert trial.state == Trial.State.EVALUATING + assert trial.metadata.state == Trial.State.EVALUATING assert trial.id == trial_id assert trial.config == {"a": "b"} assert trial.metadata == Trial.MetaData( id=trial_id, + state=Trial.State.EVALUATING, time_sampled=time_sampled, previous_trial_id=previous_trial, previous_trial_location="0", @@ -144,11 +147,12 @@ def test_trial_as_success_after_being_progress() -> None: time_end=time_end, ) - assert trial.state == Trial.State.SUCCESS + assert trial.metadata.state == Trial.State.SUCCESS assert trial.id == trial_id assert trial.config == {"a": "b"} assert trial.metadata == Trial.MetaData( id=trial_id, + state=Trial.State.SUCCESS, time_sampled=time_sampled, previous_trial_location="0", location="1", @@ -208,11 +212,12 @@ def test_trial_as_failed_with_nan_loss_and_in_cost() -> None: extra=extra, time_end=time_end, ) - assert trial.state == Trial.State.FAILED + assert trial.metadata.state == Trial.State.FAILED assert trial.id == trial_id assert trial.config == {"a": "b"} assert trial.metadata == Trial.MetaData( id=trial_id, + state=Trial.State.FAILED, time_sampled=time_sampled, previous_trial_id=previous_trial, sampling_worker_id=sampling_worker_id, @@ -273,11 +278,12 @@ def test_trial_as_crashed_with_err_and_tb() -> None: time_end=time_end, ) - assert trial.state == Trial.State.CRASHED + assert trial.metadata.state == Trial.State.CRASHED assert trial.id == trial_id assert trial.config == {"a": "b"} assert trial.metadata == Trial.MetaData( id=trial_id, + state=Trial.State.CRASHED, time_sampled=time_sampled, previous_trial_id=previous_trial, sampling_worker_id=sampling_worker_id, From 93c8d319b473adb158f0fc817c33937ee08a7dd3 Mon Sep 
17 00:00:00 2001 From: eddiebergman Date: Tue, 10 Dec 2024 18:02:55 +0100 Subject: [PATCH 52/56] fix: Remove dummy check --- neps/utils/files.py | 1 - 1 file changed, 1 deletion(-) diff --git a/neps/utils/files.py b/neps/utils/files.py index 585be708..b54a6a3b 100644 --- a/neps/utils/files.py +++ b/neps/utils/files.py @@ -19,7 +19,6 @@ CSafeLoader as SafeLoader, # type: ignore ) except ImportError as e: - raise ImportError() from e from yaml import SafeLoader, YamlDumper # type: ignore From a3610bc909d5a00056d3951d4b1b29a280e76919 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Tue, 10 Dec 2024 18:07:31 +0100 Subject: [PATCH 53/56] optim: More aggresive consolidation of cache --- neps/env.py | 2 +- neps/utils/files.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/neps/env.py b/neps/env.py index 637452c8..07b6c86c 100644 --- a/neps/env.py +++ b/neps/env.py @@ -108,7 +108,7 @@ def yaml_or_json(e: str) -> Literal["yaml", "json"]: TRIAL_CACHE_MAX_UPDATES_BEFORE_CONSOLIDATION = get_env( "NEPS_TRIAL_CACHE_MAX_UPDATES_BEFORE_CONSOLIDATION", parse=int, - default=30, + default=10, ) CONFIG_SERIALIZE_FORMAT: Literal["yaml", "json"] = get_env( # type: ignore "NEPS_CONFIG_SERIALIZE_FORMAT", diff --git a/neps/utils/files.py b/neps/utils/files.py index b54a6a3b..face3fca 100644 --- a/neps/utils/files.py +++ b/neps/utils/files.py @@ -18,7 +18,7 @@ CDumper as YamlDumper, # type: ignore CSafeLoader as SafeLoader, # type: ignore ) -except ImportError as e: +except ImportError: from yaml import SafeLoader, YamlDumper # type: ignore From 7a749b1cbc33895bf42bdcba618d97991e46d87b Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Wed, 11 Dec 2024 10:04:08 +0100 Subject: [PATCH 54/56] feat: Pass in batch size --- neps/api.py | 7 +- neps/exceptions.py | 33 ++-- neps/optimizers/base_optimizer.py | 21 ++- .../bayesian_optimization/optimizer.py | 59 ++++-- neps/optimizers/random_search/optimizer.py | 23 ++- neps/runtime.py | 55 ++++-- neps/state/neps_state.py | 171 
+++++++++++++----- neps/state/settings.py | 3 + neps/utils/run_args.py | 1 + 9 files changed, 263 insertions(+), 110 deletions(-) diff --git a/neps/api.py b/neps/api.py index a140a7f6..487425f3 100644 --- a/neps/api.py +++ b/neps/api.py @@ -44,6 +44,7 @@ def run( loss_value_on_error: None | float = Default(None), cost_value_on_error: None | float = Default(None), pre_load_hooks: Iterable | None = Default(None), + sample_batch_size: int | None = Default(None), searcher: ( Literal[ "default", @@ -98,6 +99,8 @@ def run( cost_value_on_error: Setting this and loss_value_on_error to any float will supress any error and will use given cost value instead. default: None pre_load_hooks: List of functions that will be called before load_results(). + sample_batch_size: The number of samples to ask for in a single call to the + optimizer. searcher: Which optimizer to use. Can be a string identifier, an instance of BaseOptimizer, or a Path to a custom optimizer. **searcher_kwargs: Will be passed to the searcher. This is usually only needed by @@ -236,6 +239,7 @@ def run( ignore_errors=settings.ignore_errors, overwrite_optimization_dir=settings.overwrite_working_directory, pre_load_hooks=settings.pre_load_hooks, + sample_batch_size=settings.sample_batch_size, ) if settings.post_run_summary: @@ -278,7 +282,8 @@ def _run_args( "mobster", "asha", ] - | BaseOptimizer | dict + | BaseOptimizer + | dict ) = "default", **searcher_kwargs, ) -> tuple[BaseOptimizer, dict]: diff --git a/neps/exceptions.py b/neps/exceptions.py index bcfe198f..67ed428f 100644 --- a/neps/exceptions.py +++ b/neps/exceptions.py @@ -2,6 +2,8 @@ from __future__ import annotations +from typing import Any + class NePSError(Exception): """Base class for all NePS exceptions. 
@@ -11,35 +13,22 @@ class NePSError(Exception): """ -class VersionMismatchError(NePSError): - """Raised when the version of a resource does not match the expected version.""" - - -class VersionedResourceAlreadyExistsError(NePSError): - """Raised when a version already exists when trying to create a new versioned - data. - """ - - -class VersionedResourceRemovedError(NePSError): - """Raised when a version already exists when trying to create a new versioned - data. - """ - - -class VersionedResourceDoesNotExistError(NePSError): - """Raised when a versioned resource does not exist at a location.""" - - class LockFailedError(NePSError): """Raised when a lock cannot be acquired.""" -class TrialAlreadyExistsError(VersionedResourceAlreadyExistsError): +class TrialAlreadyExistsError(NePSError): """Raised when a trial already exists in the store.""" + def __init__(self, trial_id: str, *args: Any) -> None: + super().__init__(trial_id, *args) + self.trial_id = trial_id + + def __str__(self) -> str: + return f"Trial with id {self.trial_id} already exists!" + -class TrialNotFoundError(VersionedResourceDoesNotExistError): +class TrialNotFoundError(NePSError): """Raised when a trial already exists in the store.""" diff --git a/neps/optimizers/base_optimizer.py b/neps/optimizers/base_optimizer.py index 898508c7..fad50724 100644 --- a/neps/optimizers/base_optimizer.py +++ b/neps/optimizers/base_optimizer.py @@ -4,7 +4,7 @@ from abc import abstractmethod from collections.abc import Mapping from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, overload from neps.state.trial import Report, Trial @@ -106,12 +106,29 @@ def __init__( self.learning_curve_on_error = learning_curve_on_error self.ignore_errors = ignore_errors + @overload + def ask( + self, + trials: Mapping[str, Trial], + budget_info: BudgetInfo | None, + n: int, + ) -> list[SampledConfig]: ... 
+ + @overload + def ask( + self, + trials: Mapping[str, Trial], + budget_info: BudgetInfo | None, + n: None = None, + ) -> SampledConfig: ... + @abstractmethod def ask( self, trials: Mapping[str, Trial], budget_info: BudgetInfo | None, - ) -> SampledConfig: + n: int | None = None, + ) -> SampledConfig | list[SampledConfig]: """Sample a new configuration. Args: diff --git a/neps/optimizers/bayesian_optimization/optimizer.py b/neps/optimizers/bayesian_optimization/optimizer.py index bf491e1c..4c8aa1f9 100644 --- a/neps/optimizers/bayesian_optimization/optimizer.py +++ b/neps/optimizers/bayesian_optimization/optimizer.py @@ -127,35 +127,58 @@ def __init__( self.cost_on_log_scale = cost_on_log_scale self.device = device self.sample_default_first = sample_default_first - self.n_initial_design = initial_design_size self.init_design: list[dict[str, Any]] | None = None + if initial_design_size is not None: + self.n_initial_design = initial_design_size + else: + self.n_initial_design = len(pipeline_space.numerical) + len( + pipeline_space.categoricals + ) + @override def ask( self, trials: Mapping[str, Trial], budget_info: BudgetInfo | None = None, - ) -> SampledConfig: + n: int | None = None, + ) -> SampledConfig | list[SampledConfig]: + _n = 1 if n is None else n n_sampled = len(trials) - config_id = str(n_sampled + 1) + config_ids = iter(str(n_sampled + i) for i in range(_n + 1)) space = self.pipeline_space # If we havn't passed the intial design phase if self.init_design is None: + if n is not None and self.n_initial_design < n: + init_design_size = n + else: + init_design_size = self.n_initial_design self.init_design = make_initial_design( space=space, encoder=self.encoder, sample_default_first=self.sample_default_first, sampler=self.prior if self.prior is not None else "sobol", seed=None, # TODO: Seeding - sample_size=( - "ndim" if self.n_initial_design is None else self.n_initial_design - ), + sample_size=init_design_size, sample_fidelity="max", ) - if n_sampled < 
len(self.init_design): - return SampledConfig(id=config_id, config=self.init_design[n_sampled]) + sampled_configs = [ + SampledConfig(id=config_id, config=config) + for config_id, config in zip( + config_ids, self.init_design[n_sampled : n_sampled + _n], strict=False + ) + ] + if len(sampled_configs) == _n: + if n is None: + return sampled_configs[0] + + return sampled_configs + + assert len(sampled_configs) < _n + + _n = _n - len(sampled_configs) # Otherwise, we encode trials and setup to fit and acquire from a GP data, encoder = encode_trials_for_gp( @@ -185,7 +208,7 @@ def ask( prior = None if pibo_exp_term < 1e-4 else self.prior gp = make_default_single_obj_gp(x=data.x, y=data.y, encoder=encoder) - candidate = fit_and_acquire_from_gp( + candidates = fit_and_acquire_from_gp( gp=gp, x_train=data.x, encoder=encoder, @@ -200,11 +223,25 @@ def ask( prune_baseline=True, ), prior=prior, + n_candidates_required=_n, pibo_exp_term=pibo_exp_term, costs=data.cost if self.use_cost else None, cost_percentage_used=cost_percent, costs_on_log_scale=self.cost_on_log_scale, ) - config = encoder.decode(candidate)[0] - return SampledConfig(id=config_id, config=config) + config_ids = list(config_ids) + print(_n, len(candidates), len(config_ids)) # noqa: T201 + + configs = encoder.decode(candidates) + sampled_configs.extend( + [ + SampledConfig(id=config_id, config=config) + for config_id, config in zip(config_ids, configs, strict=True) + ] + ) + + if n is None: + return sampled_configs[0] + + return sampled_configs diff --git a/neps/optimizers/random_search/optimizer.py b/neps/optimizers/random_search/optimizer.py index 8bcc8178..a5df59ad 100644 --- a/neps/optimizers/random_search/optimizer.py +++ b/neps/optimizers/random_search/optimizer.py @@ -56,9 +56,22 @@ def ask( self, trials: Mapping[str, Trial], budget_info: BudgetInfo | None, - ) -> SampledConfig: + n: int | None = None, + ) -> SampledConfig | list[SampledConfig]: n_trials = len(trials) - config = 
self.sampler.sample_one(to=self.encoder.domains) - config_dict = self.encoder.decode_one(config) - config_id = str(n_trials + 1) - return SampledConfig(config=config_dict, id=config_id, previous_config_id=None) + _n = 1 if n is None else n + configs = self.sampler.sample(_n, to=self.encoder.domains) + config_dicts = self.encoder.decode(configs) + if n == 1: + config = config_dicts[0] + config_id = str(n_trials + 1) + return SampledConfig(config=config, id=config_id, previous_config_id=None) + + return [ + SampledConfig( + config=config, + id=str(n_trials + i + 1), + previous_config_id=None, + ) + for i, config in enumerate(config_dicts) + ] diff --git a/neps/runtime.py b/neps/runtime.py index 47b08312..6ea490b6 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -300,10 +300,12 @@ def _check_global_stopping_criterion( ) -> str | Literal[False]: if self.settings.max_evaluations_total is not None: if self.settings.include_in_progress_evaluations_towards_maximum: - # NOTE: We can just use the sum of trials in this case as they - # either have a report, are pending or being evaluated. There - # are also crashed and unknown states which we include into this. - count = len(trials) + count = sum( + 1 + for _, trial in trials.items() + if trial.metadata.state + not in (Trial.State.PENDING, Trial.State.SUBMITTED) + ) else: # This indicates they have completed. 
count = sum(1 for _, trial in trials.items() if trial.report is not None) @@ -399,32 +401,45 @@ def _get_next_trial(self) -> Trial | Literal["break"]: ) return earliest_pending - sampled_trial = self.state._sample_trial( + sampled_trials = self.state._sample_trial( optimizer=self.optimizer, worker_id=self.worker_id, trials=trials, + n=self.settings.batch_size, ) + if isinstance(sampled_trials, Trial): + this_workers_trial = sampled_trials + else: + this_workers_trial = sampled_trials[0] + sampled_trials[1:] with self.state._trial_lock.lock(worker_id=self.worker_id), gc_disabled(): + this_workers_trial.set_evaluating( + time_started=time.time(), + worker_id=self.worker_id, + ) try: - sampled_trial.set_evaluating( - time_started=time.time(), - worker_id=self.worker_id, - ) - self.state._trials.new_trial(sampled_trial) - logger.info( - "Worker '%s' sampled new trial: %s.", - self.worker_id, - sampled_trial.id, - ) - return sampled_trial + self.state._trials.new_trial(sampled_trials) + if isinstance(sampled_trials, Trial): + logger.info( + "Worker '%s' sampled new trial: %s.", + self.worker_id, + this_workers_trial.id, + ) + else: + logger.info( + "Worker '%s' sampled new trials: %s.", + self.worker_id, + ",".join(trial.id for trial in sampled_trials), + ) + return this_workers_trial except TrialAlreadyExistsError as e: - if sampled_trial.id in trials: + if e.trial_id in trials: logger.error( "The new sampled trial was given an id of '%s', yet this" " exists in the loaded in trials given to the optimizer. This" " indicates a bug with the optimizers allocation of ids.", - sampled_trial.id, + e.trial_id, ) else: _grace = DefaultWorker._GRACE @@ -439,7 +454,7 @@ def _get_next_trial(self) -> Trial | Literal["break"]: " '%s's to '%s's. 
You can control the initial" " grace with 'NEPS_FS_SYNC_GRACE_BASE' and the increment with" " 'NEPS_FS_SYNC_GRACE_INC'.", - sampled_trial.id, + e.trial_id, _grace, _grace + _inc, ) @@ -595,6 +610,7 @@ def _launch_runtime( # noqa: PLR0913 overwrite_optimization_dir: bool, max_evaluations_total: int | None, max_evaluations_for_worker: int | None, + sample_batch_size: int | None, pre_load_hooks: Iterable[Callable[[BaseOptimizer], BaseOptimizer]] | None, ) -> None: if overwrite_optimization_dir and optimization_dir.exists(): @@ -643,6 +659,7 @@ def _launch_runtime( # noqa: PLR0913 if ignore_errors else OnErrorPossibilities.RAISE_ANY_ERROR ), + batch_size=sample_batch_size, default_report_values=DefaultReportValues( loss_value_on_error=loss_value_on_error, cost_value_on_error=cost_value_on_error, diff --git a/neps/state/neps_state.py b/neps/state/neps_state.py index e4cbd4b3..e428d94b 100644 --- a/neps/state/neps_state.py +++ b/neps/state/neps_state.py @@ -36,6 +36,7 @@ TRIAL_FILELOCK_TIMEOUT, ) from neps.exceptions import NePSError, TrialAlreadyExistsError, TrialNotFoundError +from neps.optimizers.base_optimizer import BaseOptimizer from neps.state.err_dump import ErrDump from neps.state.filebased import ( FileLocker, @@ -48,7 +49,6 @@ from neps.utils.files import atomic_write, deserialize, serialize if TYPE_CHECKING: - from neps.optimizers.base_optimizer import BaseOptimizer from neps.state.optimizer import OptimizationState logger = logging.getLogger(__name__) @@ -112,6 +112,8 @@ def _read_pkl_and_maybe_consolidate( if isinstance(datum, dict): assert len(trials) == 0, "Multiple caches present." trials = datum + elif isinstance(datum, list): + updates.extend(datum) else: assert isinstance(datum, Trial), "Not a trial." 
updates.append(datum) @@ -150,21 +152,39 @@ def latest(self) -> dict[str, Trial]: return self._read_pkl_and_maybe_consolidate() - def new_trial(self, trial: Trial) -> None: - config_path = self.directory / f"config_{trial.id}" - if config_path.exists(): - raise TrialAlreadyExistsError(trial.id, config_path) - - bytes_ = pickle.dumps(trial, protocol=pickle.HIGHEST_PROTOCOL) - with atomic_write(self.cache_path, "ab") as f: - f.write(bytes_) - - config_path.mkdir(parents=True, exist_ok=True) - ReaderWriterTrial.write( - trial, - self.directory / f"config_{trial.id}", - hints=["config", "metadata"], - ) + def new_trial(self, trial: Trial | list[Trial]) -> None: + if isinstance(trial, Trial): + config_path = self.directory / f"config_{trial.id}" + if config_path.exists(): + raise TrialAlreadyExistsError(trial.id, config_path) + + bytes_ = pickle.dumps(trial, protocol=pickle.HIGHEST_PROTOCOL) + with atomic_write(self.cache_path, "ab") as f: + f.write(bytes_) + + config_path.mkdir(parents=True, exist_ok=True) + ReaderWriterTrial.write( + trial, + self.directory / f"config_{trial.id}", + hints=["config", "metadata"], + ) + else: + for child_trial in trial: + config_path = self.directory / f"config_{child_trial.id}" + if config_path.exists(): + raise TrialAlreadyExistsError(child_trial.id, config_path) + config_path.mkdir(parents=True, exist_ok=True) + + bytes_ = pickle.dumps(trial, protocol=pickle.HIGHEST_PROTOCOL) + with atomic_write(self.cache_path, "ab") as f: + f.write(bytes_) + + for child_trial in trial: + ReaderWriterTrial.write( + child_trial, + self.directory / f"config_{child_trial.id}", + hints=["config", "metadata"], + ) def update_trial( self, @@ -181,7 +201,9 @@ def update_trial( def load_trial_from_disk(self, trial_id: str) -> Trial: config_path = self.directory / f"config_{trial_id}" if not config_path.exists(): - raise TrialNotFoundError(trial_id, config_path) + raise TrialNotFoundError( + f"Trial {trial_id} not found at expected path of {config_path}." 
+ ) return ReaderWriterTrial.read(config_path) @@ -212,18 +234,34 @@ def lock_and_read_trials(self) -> dict[str, Trial]: with self._trial_lock.lock(): return self._trials.latest() - def lock_and_sample_trial(self, optimizer: BaseOptimizer, *, worker_id: str) -> Trial: + @overload + def lock_and_sample_trial( + self, optimizer: BaseOptimizer, *, worker_id: str, n: None = None + ) -> Trial: ... + @overload + def lock_and_sample_trial( + self, optimizer: BaseOptimizer, *, worker_id: str, n: int + ) -> list[Trial]: ... + + def lock_and_sample_trial( + self, optimizer: BaseOptimizer, *, worker_id: str, n: int | None = None + ) -> Trial | list[Trial]: """Acquire the state lock and sample a trial.""" with self._optimizer_lock.lock(): with self._trial_lock.lock(): trials = self._trials.latest() - trial = self._sample_trial(optimizer, trials=trials, worker_id=worker_id) + trials = self._sample_trial( + optimizer, + trials=trials, + worker_id=worker_id, + n=n, + ) with self._trial_lock.lock(): - self._trials.new_trial(trial) + self._trials.new_trial(trials) - return trial + return trials def lock_and_report_trial_evaluation( self, @@ -236,14 +274,37 @@ def lock_and_report_trial_evaluation( with self._trial_lock.lock(), self._err_lock.lock(): self._report_trial_evaluation(trial, report, worker_id=worker_id) + @overload def _sample_trial( self, optimizer: BaseOptimizer, *, worker_id: str, trials: dict[str, Trial], + n: int, + _sample_hooks: list[Callable] | None = ..., + ) -> list[Trial]: ... + + @overload + def _sample_trial( + self, + optimizer: BaseOptimizer, + *, + worker_id: str, + trials: dict[str, Trial], + n: None, + _sample_hooks: list[Callable] | None = ..., + ) -> Trial: ... + + def _sample_trial( + self, + optimizer: BaseOptimizer, + *, + worker_id: str, + trials: dict[str, Trial], + n: int | None, _sample_hooks: list[Callable] | None = None, - ) -> Trial: + ) -> Trial | list[Trial]: """Sample a new trial from the optimizer. !!! 
warning @@ -253,6 +314,7 @@ def _sample_trial( Args: optimizer: The optimizer to sample the trial from. worker_id: The worker that is sampling the trial. + n: The number of trials to sample. trials: The current trials. _sample_hooks: A list of hooks to apply to the optimizer before sampling. @@ -268,8 +330,9 @@ def _sample_trial( # it to be done after `load_results`... I hope not. if _sample_hooks is not None: for hook in _sample_hooks: - optimizer = hook(optimizer) + optimizer = hook(optimizer) # type: ignore + assert isinstance(optimizer, BaseOptimizer) if opt_state.budget is not None: # NOTE: All other values of budget are ones that should remain # constant, there are currently only these two which are dynamic as @@ -281,43 +344,51 @@ def _sample_trial( ) opt_state.budget.used_evaluations = len(trials) - sampled_config_maybe_new_opt_state = optimizer.ask( + sampled_configs = optimizer.ask( trials=trials, budget_info=opt_state.budget.clone(), + n=n, ) - if isinstance(sampled_config_maybe_new_opt_state, tuple): - sampled_config, shared_state = sampled_config_maybe_new_opt_state - else: - sampled_config = sampled_config_maybe_new_opt_state - shared_state = opt_state.shared_state - - if sampled_config.previous_config_id is not None: - previous_trial = trials.get(sampled_config.previous_config_id) - if previous_trial is None: - raise ValueError( - f"Previous trial '{sampled_config.previous_config_id}' not found." - ) - previous_trial_location = previous_trial.metadata.location - else: - previous_trial_location = None - - trial = Trial.new( - trial_id=sampled_config.id, - location="", # HACK: This will be set by the `TrialRepo` in `put_new` - config=sampled_config.config, - previous_trial=sampled_config.previous_config_id, - previous_trial_location=previous_trial_location, - time_sampled=time.time(), - worker_id=worker_id, - ) + if not isinstance(sampled_configs, list): + sampled_configs = [sampled_configs] + + # TODO: Not implemented yet. 
+ shared_state = opt_state.shared_state + + sampled_trials: list[Trial] = [] + for sampled_config in sampled_configs: + if sampled_config.previous_config_id is not None: + previous_trial = trials.get(sampled_config.previous_config_id) + if previous_trial is None: + raise ValueError( + f"Previous trial '{sampled_config.previous_config_id}' not found." + ) + previous_trial_location = previous_trial.metadata.location + else: + previous_trial_location = None + + trial = Trial.new( + trial_id=sampled_config.id, + location="", # HACK: This will be set by the `TrialRepo` in `put_new` + config=sampled_config.config, + previous_trial=sampled_config.previous_config_id, + previous_trial_location=previous_trial_location, + time_sampled=time.time(), + worker_id=worker_id, + ) + sampled_trials.append(trial) opt_state.shared_state = shared_state opt_state.seed_snapshot.recapture() with self._optimizer_state_path.open("wb") as f: pickle.dump(opt_state, f, protocol=pickle.HIGHEST_PROTOCOL) - return trial + if n is None: + assert len(sampled_trials) == 1 + return sampled_trials[0] + + return sampled_trials def _report_trial_evaluation( self, diff --git a/neps/state/settings.py b/neps/state/settings.py index f34a9435..167c24f9 100644 --- a/neps/state/settings.py +++ b/neps/state/settings.py @@ -72,6 +72,9 @@ class WorkerSettings: default_report_values: DefaultReportValues """Values to use when an error occurs or was not specified.""" + batch_size: int | None + """The number of configurations to sample in a single batch.""" + # --------- Global Stopping Criterion --------- max_evaluations_total: int | None """The maximum number of evaluations to run in total. 
diff --git a/neps/utils/run_args.py b/neps/utils/run_args.py index 9f5cc60d..385555b6 100644 --- a/neps/utils/run_args.py +++ b/neps/utils/run_args.py @@ -531,6 +531,7 @@ def __init__(self, func_args: dict, yaml_args: Path | str | Default | None = Non self.cost_value_on_error = UNSET self.pre_load_hooks = UNSET self.searcher = UNSET + self.sample_batch_size = UNSET self.searcher_kwargs = UNSET if not isinstance(yaml_args, Default) and yaml_args is not None: From 722ef9d7fa611687252990b18f700363b7ee97ef Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Wed, 11 Dec 2024 10:51:31 +0100 Subject: [PATCH 55/56] fix: Initial design size --- .../bayesian_optimization/optimizer.py | 59 ++++++++++--------- neps/optimizers/initial_design.py | 7 +-- 2 files changed, 32 insertions(+), 34 deletions(-) diff --git a/neps/optimizers/bayesian_optimization/optimizer.py b/neps/optimizers/bayesian_optimization/optimizer.py index 4c8aa1f9..6755e943 100644 --- a/neps/optimizers/bayesian_optimization/optimizer.py +++ b/neps/optimizers/bayesian_optimization/optimizer.py @@ -145,40 +145,46 @@ def ask( ) -> SampledConfig | list[SampledConfig]: _n = 1 if n is None else n n_sampled = len(trials) - config_ids = iter(str(n_sampled + i) for i in range(_n + 1)) + config_ids = iter(str(i + 1) for i in range(n_sampled, n_sampled + _n)) space = self.pipeline_space - # If we havn't passed the intial design phase - if self.init_design is None: - if n is not None and self.n_initial_design < n: - init_design_size = n - else: - init_design_size = self.n_initial_design - self.init_design = make_initial_design( + sampled_configs: list[SampledConfig] = [] + + # If the amount of configs evaluated is less than the initial design + # requirement, keep drawing from initial design + n_evaluated = sum( + 1 + for trial in trials.values() + if trial.report is not None and trial.report.loss is not None + ) + if n_evaluated < self.n_initial_design: + design_samples = make_initial_design( space=space, 
encoder=self.encoder, - sample_default_first=self.sample_default_first, - sampler=self.prior if self.prior is not None else "sobol", + sample_default_first=( + self.sample_default_first if n_sampled == 0 else False + ), + sampler=self.prior if self.prior is not None else "uniform", seed=None, # TODO: Seeding - sample_size=init_design_size, + sample_size=_n, sample_fidelity="max", ) - sampled_configs = [ - SampledConfig(id=config_id, config=config) - for config_id, config in zip( - config_ids, self.init_design[n_sampled : n_sampled + _n], strict=False + sampled_configs.extend( + [ + SampledConfig(id=config_id, config=config) + for config_id, config in zip( + config_ids, + design_samples, + strict=False, + ) + ] ) - ] - if len(sampled_configs) == _n: - if n is None: - return sampled_configs[0] - - return sampled_configs + if len(sampled_configs) == _n: + if n is None: + return sampled_configs[0] - assert len(sampled_configs) < _n - - _n = _n - len(sampled_configs) + return sampled_configs # Otherwise, we encode trials and setup to fit and acquire from a GP data, encoder = encode_trials_for_gp( @@ -202,7 +208,7 @@ def ask( prior = None if self.prior: pibo_exp_term = _pibo_exp_term( - n_sampled, encoder.ncols, len(self.init_design) + n_sampled, encoder.ncols, self.n_initial_design ) # If the exp term is insignificant, skip prior acq. 
weighting prior = None if pibo_exp_term < 1e-4 else self.prior @@ -230,9 +236,6 @@ def ask( costs_on_log_scale=self.cost_on_log_scale, ) - config_ids = list(config_ids) - print(_n, len(candidates), len(config_ids)) # noqa: T201 - configs = encoder.decode(candidates) sampled_configs.extend( [ diff --git a/neps/optimizers/initial_design.py b/neps/optimizers/initial_design.py index ee64fd2a..19892d43 100644 --- a/neps/optimizers/initial_design.py +++ b/neps/optimizers/initial_design.py @@ -100,12 +100,7 @@ def make_initial_design( # noqa: PLR0912, C901 ) if sample_default_first: - # TODO: No way to pass a seed to the sampler - default = { - name: hp.default if hp.default is not None else hp.sample_value() - for name, hp in space.hyperparameters.items() - } - configs.append({**default, **fids()}) + configs.append({**space.default_config, **fids()}) ndims = len(space.numerical) + len(space.categoricals) if sample_size == "ndim": From f62909624682e6894762d0e744d389203c966be2 Mon Sep 17 00:00:00 2001 From: eddiebergman Date: Wed, 11 Dec 2024 11:30:25 +0100 Subject: [PATCH 56/56] ux: Improve log message for lock timeout --- neps/runtime.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/neps/runtime.py b/neps/runtime.py index 6ea490b6..f4f078f5 100644 --- a/neps/runtime.py +++ b/neps/runtime.py @@ -20,6 +20,8 @@ TypeVar, ) +from portalocker import portalocker + from neps.env import ( FS_SYNC_GRACE_BASE, FS_SYNC_GRACE_INC, @@ -529,12 +531,24 @@ def run(self) -> None: # noqa: C901, PLR0915 _repeated_fail_get_next_trial_count = 0 except Exception as e: _repeated_fail_get_next_trial_count += 1 - logger.debug( - "Worker '%s': Error while trying to get the next trial to evaluate.", - self.worker_id, - exc_info=True, - ) - time.sleep(1) # Help stagger retries + if isinstance(e, portalocker.exceptions.LockException): + logger.debug( + "Worker '%s': Timeout while trying to get the next trial to" + " evaluate. 
If you are using a model based optimizer, such as" + " Bayesian Optimization, this can occur as the number of" + " configurations get large. There's not much to do here" + " and we will retry to obtain the lock.", + self.worker_id, + exc_info=True, + ) + else: + logger.debug( + "Worker '%s': Error while trying to get the next trial to" + " evaluate.", + self.worker_id, + exc_info=True, + ) + time.sleep(1) # Help stagger retries # NOTE: This is to prevent any infinite loops if we can't get a trial if _repeated_fail_get_next_trial_count >= MAX_RETRIES_GET_NEXT_TRIAL: raise WorkerFailedToGetPendingTrialsError(