From 9bf350769a3d3d8389b8b548b46e83520fd0d4e7 Mon Sep 17 00:00:00 2001 From: Daniel Cohen Date: Tue, 3 Dec 2024 15:18:28 -0800 Subject: [PATCH] Combine experiment and gs fields into AnalysisBase (#3137) Summary: AnalysisBase will have optional `_experiment` and `_generation_strategy` fields, with getter and setter properties, as well as the `standard_generation_strategy` prop from `Scheduler`. `Scheduler` and `AxClient` will inherit these from it. The naming is less than optimal with AnalysisBase holding the experiment and GS. This is otherwise a no-op change designed to reduce pyre errors. Differential Revision: D66712036 --- ax/modelbridge/tests/test_prediction_utils.py | 8 +-- ax/service/ax_client.py | 56 +++++++------------ ax/service/scheduler.py | 17 ------ ax/service/tests/test_ax_client.py | 45 ++++++++------- ax/service/utils/analysis_base.py | 54 +++++++++++++++++- 5 files changed, 100 insertions(+), 80 deletions(-) diff --git a/ax/modelbridge/tests/test_prediction_utils.py b/ax/modelbridge/tests/test_prediction_utils.py index 0354cdffe47..d1929867f1e 100644 --- a/ax/modelbridge/tests/test_prediction_utils.py +++ b/ax/modelbridge/tests/test_prediction_utils.py @@ -28,7 +28,7 @@ def test_predict_at_point(self) -> None: observation_features = ObservationFeatures(parameters={"x1": 0.3, "x2": 0.5}) y_hat, se_hat = predict_at_point( - model=none_throws(ax_client.generation_strategy.model), + model=none_throws(ax_client.standard_generation_strategy.model), obsf=observation_features, metric_names={"test_metric1"}, ) @@ -37,7 +37,7 @@ def test_predict_at_point(self) -> None: self.assertEqual(len(se_hat), 1) y_hat, se_hat = predict_at_point( - model=none_throws(ax_client.generation_strategy.model), + model=none_throws(ax_client.standard_generation_strategy.model), obsf=observation_features, metric_names={"test_metric1", "test_metric2", "test_metric:agg"}, scalarized_metric_config=[ @@ -51,7 +51,7 @@ def test_predict_at_point(self) -> None: self.assertEqual(len(se_hat), 3) y_hat, se_hat = predict_at_point( - model=none_throws(ax_client.generation_strategy.model), + model=none_throws(ax_client.standard_generation_strategy.model), obsf=observation_features, metric_names={"test_metric1"}, scalarized_metric_config=[ @@ -75,7 +75,7 @@ def test_predict_by_features(self) -> None: 20: ObservationFeatures(parameters={"x1": 0.8, "x2": 0.5}), } predictions_map = predict_by_features( - model=none_throws(ax_client.generation_strategy.model), + model=none_throws(ax_client.standard_generation_strategy.model), label_to_feature_dict=observation_features_dict, metric_names={"test_metric1"}, ) diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py index 4e48cddb621..67881c0d1d3 100644 --- a/ax/service/ax_client.py +++ b/ax/service/ax_client.py @@ -599,8 +599,8 @@ def get_next_trial( # TODO[T79183560]: Ensure correct handling of generator run when using # foreign keys. self._update_generation_strategy_in_db_if_possible( - generation_strategy=self.generation_strategy, - new_generator_runs=[self.generation_strategy._generator_runs[-1]], + generation_strategy=self.standard_generation_strategy, + new_generator_runs=[self.standard_generation_strategy._generator_runs[-1]], ) return none_throws(trial.arm).parameters, trial.index @@ -625,7 +625,7 @@ def get_current_trial_generation_limit(self) -> tuple[int, bool]: if self.generation_strategy._experiment is None: self.generation_strategy.experiment = self.experiment - return self.generation_strategy.current_generator_run_limit() + return self.standard_generation_strategy.current_generator_run_limit() def get_next_trials( self, @@ -950,7 +950,7 @@ def get_max_parallelism(self) -> list[tuple[int, int]]: Mapping of form {num_trials -> max_parallelism_setting}. """ parallelism_settings = [] - for step in self.generation_strategy._steps: + for step in self.standard_generation_strategy._steps: parallelism_settings.append( (step.num_trials, step.max_parallelism or step.num_trials) ) @@ -1071,7 +1071,7 @@ def get_contour_plot( raise ValueError( f'Metric "{metric_name}" is not associated with this optimization.' ) - if self.generation_strategy.model is not None: + if self.standard_generation_strategy.model is not None: try: logger.info( f"Retrieving contour plot with parameter '{param_x}' on X-axis " @@ -1079,7 +1079,7 @@ def get_contour_plot( "Remaining parameters are affixed to the middle of their range." ) return plot_contour( - model=none_throws(self.generation_strategy.model), + model=none_throws(self.standard_generation_strategy.model), param_x=param_x, param_y=param_y, metric_name=metric_name, @@ -1089,8 +1089,8 @@ def get_contour_plot( # Some models don't implement '_predict', which is needed # for the contour plots. logger.info( - f"Model {self.generation_strategy.model} does not implement " - "`predict`, so it cannot be used to generate a response " + f"Model {self.standard_generation_strategy.model} does not " + "implement `predict`, so it cannot be used to generate a response " "surface plot." ) raise UnsupportedPlotError( @@ -1112,14 +1112,14 @@ def get_feature_importances(self, relative: bool = True) -> AxPlotConfig: """ if not self.experiment.trials: raise ValueError("Cannot generate plot as there are no trials.") - cur_model = self.generation_strategy.model + cur_model = self.standard_generation_strategy.model if cur_model is not None: try: return plot_feature_importance_by_feature(cur_model, relative=relative) except NotImplementedError: logger.info( - f"Model {self.generation_strategy.model} does not implement " - "`feature_importances`, so it cannot be used to generate " + f"Model {self.standard_generation_strategy.model} does not " + "implement `feature_importances`, so it cannot be used to generate " "this plot. Only certain models, implement feature importances." ) @@ -1247,7 +1247,8 @@ def get_model_predictions( else set(none_throws(self.experiment.metrics).keys()) ) model = none_throws( - self.generation_strategy.model, "No model has been instantiated yet." + self.standard_generation_strategy.model, + "No model has been instantiated yet.", ) # Construct a dictionary that maps from a label to an @@ -1306,8 +1307,8 @@ def fit_model(self) -> None: "At least one trial must be completed with data to fit a model." ) # Check if we should transition before generating the next candidate. - self.generation_strategy._maybe_transition_to_next_node() - self.generation_strategy._fit_current_model(data=None) + self.standard_generation_strategy._maybe_transition_to_next_node() + self.standard_generation_strategy._fit_current_model(data=None) def verify_trial_parameterization( self, trial_index: int, parameterization: TParameterization @@ -1496,29 +1497,10 @@ def from_json_snapshot( # ---------------------- Private helper methods. --------------------- - @property - def experiment(self) -> Experiment: - """Returns the experiment set on this Ax client.""" - return none_throws( - self._experiment, - ( - "Experiment not set on Ax client. Must first " - "call load_experiment or create_experiment to use handler functions." - ), - ) - def get_trial(self, trial_index: int) -> Trial: """Return a trial on experiment cast as Trial""" return checked_cast(Trial, self.experiment.trials[trial_index]) - @property - def generation_strategy(self) -> GenerationStrategy: - """Returns the generation strategy, set on this experiment.""" - return none_throws( - self._generation_strategy, - "No generation strategy has been set on this optimization yet.", - ) - @property def objective(self) -> Objective: return none_throws(self.experiment.optimization_config).objective @@ -1586,7 +1568,7 @@ def get_best_trial( ) -> tuple[int, TParameterization, TModelPredictArm | None] | None: return self._get_best_trial( experiment=self.experiment, - generation_strategy=self.generation_strategy, + generation_strategy=self.standard_generation_strategy, trial_indices=trial_indices, use_model_predictions=use_model_predictions, ) @@ -1600,7 +1582,7 @@ def get_pareto_optimal_parameters( ) -> dict[int, tuple[TParameterization, TModelPredictArm]]: return self._get_pareto_optimal_parameters( experiment=self.experiment, - generation_strategy=self.generation_strategy, + generation_strategy=self.standard_generation_strategy, trial_indices=trial_indices, use_model_predictions=use_model_predictions, ) @@ -1614,7 +1596,7 @@ def get_hypervolume( ) -> float: return BestPointMixin._get_hypervolume( experiment=self.experiment, - generation_strategy=self.generation_strategy, + generation_strategy=self.standard_generation_strategy, optimization_config=optimization_config, trial_indices=trial_indices, use_model_predictions=use_model_predictions, @@ -1817,7 +1799,7 @@ def _gen_new_generator_run( else None ) with with_rng_seed(seed=self._random_seed): - return none_throws(self.generation_strategy).gen( + return none_throws(self.standard_generation_strategy).gen( experiment=self.experiment, n=n, pending_observations=self._get_pending_observation_features( diff --git a/ax/service/scheduler.py b/ax/service/scheduler.py index dbc596b281b..d30b2aa138a 100644 --- a/ax/service/scheduler.py +++ b/ax/service/scheduler.py @@ -162,8 +162,6 @@ class Scheduler(AnalysisBase, BestPointMixin): been saved, as otherwise experiment state could get corrupted.** """ - experiment: Experiment - generation_strategy: GenerationStrategyInterface # pyre-fixme[24]: Generic type `LoggerAdapter` expects 1 type parameter. logger: LoggerAdapter # Mapping of form {short string identifier -> message to show in reported @@ -491,21 +489,6 @@ def runner(self) -> Runner: ) return runner - @property - def standard_generation_strategy(self) -> GenerationStrategy: - """Used for operations in the scheduler that can only be done with - and instance of ``GenerationStrategy``. - """ - gs = self.generation_strategy - if not isinstance(gs, GenerationStrategy): - raise NotImplementedError( - "This functionality is only supported with instances of " - "`GenerationStrategy` (one that uses `GenerationStrategy` " - "class) and not yet with other types of " - "`GenerationStrategyInterface`." - ) - return gs - def __repr__(self) -> str: """Short user-friendly string representation.""" if not hasattr(self, "experiment"): diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index efa67476ed7..ae27cf55705 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -498,7 +498,10 @@ def test_default_generation_strategy_continuous(self, _a, _b, _c, _d) -> None: """ ax_client = get_branin_optimization() self.assertEqual( - [s.model for s in none_throws(ax_client.generation_strategy)._steps], + [ + s.model + for s in none_throws(ax_client.standard_generation_strategy)._steps + ], [Models.SOBOL, Models.BOTORCH_MODULAR], ) with self.assertRaisesRegex(ValueError, ".* no trials"): @@ -713,7 +716,7 @@ def test_default_generation_strategy_continuous_for_moo( }, ) self.assertEqual( - [s.model for s in none_throws(ax_client.generation_strategy)._steps], + [s.model for s in ax_client.standard_generation_strategy._steps], [Models.SOBOL, Models.BOTORCH_MODULAR], ) with self.assertRaisesRegex(ValueError, ".* no trials"): @@ -776,7 +779,7 @@ def test_create_experiment(self) -> None: steps=[GenerationStep(model=Models.SOBOL, num_trials=30)] ) ) - with self.assertRaisesRegex(AssertionError, "Experiment not set on Ax client"): + with self.assertRaisesRegex(AssertionError, "Experiment not set on AxClient"): ax_client.experiment ax_client.create_experiment( name="test_experiment", @@ -1013,7 +1016,7 @@ def test_create_single_objective_experiment_with_objectives_dict(self) -> None: steps=[GenerationStep(model=Models.SOBOL, num_trials=30)] ) ) - with self.assertRaisesRegex(AssertionError, "Experiment not set on Ax client"): + with self.assertRaisesRegex(AssertionError, "Experiment not set on AxClient"): ax_client.experiment ax_client.create_experiment( name="test_experiment", @@ -1074,7 +1077,7 @@ def test_create_single_objective_experiment_with_objectives_dict(self) -> None: def test_create_experiment_with_metric_definitions(self) -> None: """Test basic experiment creation.""" ax_client = AxClient() - with self.assertRaisesRegex(AssertionError, "Experiment not set on Ax client"): + with self.assertRaisesRegex(AssertionError, "Experiment not set on AxClient"): ax_client.experiment metric_definitions = { @@ -1341,7 +1344,7 @@ def test_create_moo_experiment(self) -> None: steps=[GenerationStep(model=Models.SOBOL, num_trials=30)] ) ) - with self.assertRaisesRegex(AssertionError, "Experiment not set on Ax client"): + with self.assertRaisesRegex(AssertionError, "Experiment not set on AxClient"): ax_client.experiment ax_client.create_experiment( name="test_experiment", @@ -1575,10 +1578,9 @@ def test_keep_generating_without_data(self) -> None: {"name": "y", "type": "range", "bounds": [0.0, 15.0]}, ], ) - self.assertFalse( - ax_client.generation_strategy._steps[0].enforce_num_trials, False - ) - self.assertFalse(ax_client.generation_strategy._steps[1].max_parallelism, None) + gs = ax_client.standard_generation_strategy + self.assertFalse(gs._steps[0].enforce_num_trials, False) + self.assertFalse(gs._steps[1].max_parallelism, None) for _ in range(10): parameterization, trial_index = ax_client.get_next_trial() @@ -2094,14 +2096,14 @@ def test_sqa_storage(self) -> None: # pyre-fixme[6]: For 2nd param expected `Union[List[Tuple[Dict[str, U... raw_data=branin(*parameters.values()), ) - gs = ax_client.generation_strategy + gs = ax_client.standard_generation_strategy ax_client = AxClient(db_settings=db_settings) ax_client.load_experiment_from_database("test_experiment") # Some fields of the reloaded GS are not expected to be set (both will be # set during next model fitting call), so we unset them on the original GS as # well. gs._unset_non_persistent_state_fields() - ax_client.generation_strategy._unset_non_persistent_state_fields() + ax_client.standard_generation_strategy._unset_non_persistent_state_fields() self.assertEqual(gs, ax_client.generation_strategy) with self.assertRaises(ValueError): # Overwriting existing experiment. @@ -2455,8 +2457,9 @@ def helper_test_get_pareto_optimal_points( num_trials=20, outcome_constraints=outcome_constraints ) ax_client.fit_model() + gs = ax_client.standard_generation_strategy self.assertEqual( - ax_client.generation_strategy._curr.model_spec_to_gen_from.model_key, + gs._curr.model_spec_to_gen_from.model_key, "BoTorch", ) @@ -2481,7 +2484,7 @@ def helper_test_get_pareto_optimal_points( # This overwrites the `predict` call to return the original observations, # while testing the rest of the code as if we're using predictions. # pyre-fixme[16]: `Optional` has no attribute `model`. - model = ax_client.generation_strategy.model.model + model = ax_client.standard_generation_strategy.model.model ys = model.surrogate.training_data[0].Y with patch.object( model, "predict", return_value=(ys, torch.zeros(*ys.shape, ys.shape[-1])) @@ -2525,8 +2528,9 @@ def helper_test_get_pareto_optimal_points_from_sobol_step( ax_client, _ = get_branin_currin_optimization_with_N_sobol_trials( num_trials=20, minimize=minimize, outcome_constraints=outcome_constraints ) + gs = ax_client.standard_generation_strategy self.assertEqual( - ax_client.generation_strategy._curr.model_spec_to_gen_from.model_key, + gs._curr.model_spec_to_gen_from.model_key, "Sobol", ) @@ -2637,8 +2641,8 @@ def test_get_pareto_optimal_points_objective_threshold_inference( ax_client, _ = get_branin_currin_optimization_with_N_sobol_trials( num_trials=20, include_objective_thresholds=False ) - ax_client.generation_strategy._maybe_transition_to_next_node() - ax_client.generation_strategy._fit_current_model( + ax_client.standard_generation_strategy._maybe_transition_to_next_node() + ax_client.standard_generation_strategy._fit_current_model( data=ax_client.experiment.lookup_data() ) @@ -2849,7 +2853,8 @@ def test_with_hss(self) -> None: # Make sure we actually tried a Botorch iteration and all the transforms it # applies. self.assertEqual( - ax_client.generation_strategy._generator_runs[-1]._model_key, "BoTorch" + ax_client.standard_generation_strategy._generator_runs[-1]._model_key, + "BoTorch", ) self.assertEqual(len(ax_client.experiment.trials), 6) ax_client.attach_trial( @@ -2964,7 +2969,7 @@ def test_torch_device(self) -> None: torch_device=device, ) ax_client = get_branin_optimization(torch_device=device) - gpei_step_kwargs = ax_client.generation_strategy._steps[1].model_kwargs + gpei_step_kwargs = ax_client.standard_generation_strategy._steps[1].model_kwargs self.assertEqual(gpei_step_kwargs["torch_device"], device) def test_repr_function( @@ -2993,7 +2998,7 @@ def test_gen_fixed_features(self) -> None: name="fixed_features", ) with mock.patch.object( - GenerationStrategy, "gen", wraps=ax_client.generation_strategy.gen + GenerationStrategy, "gen", wraps=ax_client.standard_generation_strategy.gen ) as mock_gen: with self.subTest("fixed_features is None"): params, idx = ax_client.get_next_trial() diff --git a/ax/service/utils/analysis_base.py b/ax/service/utils/analysis_base.py index 30958ff7f6c..3f79c3c0c9b 100644 --- a/ax/service/utils/analysis_base.py +++ b/ax/service/utils/analysis_base.py @@ -14,22 +14,26 @@ from ax.analysis.plotly.parallel_coordinates import ParallelCoordinatesPlot from ax.core.experiment import Experiment from ax.core.generation_strategy_interface import GenerationStrategyInterface +from ax.modelbridge.generation_strategy import GenerationStrategy from ax.service.utils.with_db_settings_base import WithDBSettingsBase from ax.utils.common.typeutils import checked_cast +from pyre_extensions import none_throws class AnalysisBase(WithDBSettingsBase): """ Base class for analysis functionality shared between AxClient and Scheduler. + It also manages the experiment and generation strategy associated with the + instance. """ # pyre-fixme[13]: Attribute `experiment` is declared in class # `AnalysisBase` to have type `Experiment` but is never initialized - experiment: Experiment + _experiment: Experiment | None # pyre-fixme[13]: Attribute `generation_strategy` is declared in class # `AnalysisBase` to have type `GenerationStrategyInterface` but # is never initialized - generation_strategy: GenerationStrategyInterface + _generation_strategy: GenerationStrategyInterface | None def _choose_analyses(self) -> list[Analysis]: """ @@ -95,3 +99,49 @@ def compute_analyses( ) return cards + + @property + def experiment(self) -> Experiment: + """Returns the experiment set on this instance.""" + return none_throws( + self._experiment, + ( + f"Experiment not set on {self.__class__.__name__}. Must first " + "call load_experiment or create_experiment to use handler functions." + ), + ) + + @experiment.setter + def experiment(self, experiment: Experiment) -> None: + """Sets the experiment on this instance.""" + self._experiment = experiment + + @property + def generation_strategy(self) -> GenerationStrategyInterface: + """Returns the generation strategy, set on this experiment.""" + return none_throws( + self._generation_strategy, + "No generation strategy has been set on this optimization yet.", + ) + + @generation_strategy.setter + def generation_strategy( + self, generation_strategy: GenerationStrategyInterface + ) -> None: + """Sets the generation strategy on this instance.""" + self._generation_strategy = generation_strategy + + @property + def standard_generation_strategy(self) -> GenerationStrategy: + """Used for operations in the scheduler that can only be done with + and instance of ``GenerationStrategy``. + """ + gs = self.generation_strategy + if not isinstance(gs, GenerationStrategy): + raise NotImplementedError( + "This functionality is only supported with instances of " + "`GenerationStrategy` (one that uses `GenerationStrategy` " + "class) and not yet with other types of " + "`GenerationStrategyInterface`." + ) + return gs