Skip to content

Commit

Permalink
Combine experiment and gs fields into AnalysisBase (#3137)
Browse files Browse the repository at this point in the history
Summary:

AnalysisBase will have optional `_experiment` and `_generation_strategy` fields, with getter and setter properties, as well as the `standard_generation_strategy` prop from `Scheduler`.  `Scheduler` and `AxClient` will inherit these from it.

The naming is less than optimal with AnalysisBase holding the experiment and GS.

This is otherwise a no-op change designed to reduce pyre errors.

Differential Revision: D66712036
  • Loading branch information
Daniel Cohen authored and facebook-github-bot committed Dec 3, 2024
1 parent 0f2b010 commit 9bf3507
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 80 deletions.
8 changes: 4 additions & 4 deletions ax/modelbridge/tests/test_prediction_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_predict_at_point(self) -> None:

observation_features = ObservationFeatures(parameters={"x1": 0.3, "x2": 0.5})
y_hat, se_hat = predict_at_point(
model=none_throws(ax_client.generation_strategy.model),
model=none_throws(ax_client.standard_generation_strategy.model),
obsf=observation_features,
metric_names={"test_metric1"},
)
Expand All @@ -37,7 +37,7 @@ def test_predict_at_point(self) -> None:
self.assertEqual(len(se_hat), 1)

y_hat, se_hat = predict_at_point(
model=none_throws(ax_client.generation_strategy.model),
model=none_throws(ax_client.standard_generation_strategy.model),
obsf=observation_features,
metric_names={"test_metric1", "test_metric2", "test_metric:agg"},
scalarized_metric_config=[
Expand All @@ -51,7 +51,7 @@ def test_predict_at_point(self) -> None:
self.assertEqual(len(se_hat), 3)

y_hat, se_hat = predict_at_point(
model=none_throws(ax_client.generation_strategy.model),
model=none_throws(ax_client.standard_generation_strategy.model),
obsf=observation_features,
metric_names={"test_metric1"},
scalarized_metric_config=[
Expand All @@ -75,7 +75,7 @@ def test_predict_by_features(self) -> None:
20: ObservationFeatures(parameters={"x1": 0.8, "x2": 0.5}),
}
predictions_map = predict_by_features(
model=none_throws(ax_client.generation_strategy.model),
model=none_throws(ax_client.standard_generation_strategy.model),
label_to_feature_dict=observation_features_dict,
metric_names={"test_metric1"},
)
Expand Down
56 changes: 19 additions & 37 deletions ax/service/ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,8 +599,8 @@ def get_next_trial(
# TODO[T79183560]: Ensure correct handling of generator run when using
# foreign keys.
self._update_generation_strategy_in_db_if_possible(
generation_strategy=self.generation_strategy,
new_generator_runs=[self.generation_strategy._generator_runs[-1]],
generation_strategy=self.standard_generation_strategy,
new_generator_runs=[self.standard_generation_strategy._generator_runs[-1]],
)
return none_throws(trial.arm).parameters, trial.index

Expand All @@ -625,7 +625,7 @@ def get_current_trial_generation_limit(self) -> tuple[int, bool]:
if self.generation_strategy._experiment is None:
self.generation_strategy.experiment = self.experiment

return self.generation_strategy.current_generator_run_limit()
return self.standard_generation_strategy.current_generator_run_limit()

def get_next_trials(
self,
Expand Down Expand Up @@ -950,7 +950,7 @@ def get_max_parallelism(self) -> list[tuple[int, int]]:
Mapping of form {num_trials -> max_parallelism_setting}.
"""
parallelism_settings = []
for step in self.generation_strategy._steps:
for step in self.standard_generation_strategy._steps:
parallelism_settings.append(
(step.num_trials, step.max_parallelism or step.num_trials)
)
Expand Down Expand Up @@ -1071,15 +1071,15 @@ def get_contour_plot(
raise ValueError(
f'Metric "{metric_name}" is not associated with this optimization.'
)
if self.generation_strategy.model is not None:
if self.standard_generation_strategy.model is not None:
try:
logger.info(
f"Retrieving contour plot with parameter '{param_x}' on X-axis "
f"and '{param_y}' on Y-axis, for metric '{metric_name}'. "
"Remaining parameters are affixed to the middle of their range."
)
return plot_contour(
model=none_throws(self.generation_strategy.model),
model=none_throws(self.standard_generation_strategy.model),
param_x=param_x,
param_y=param_y,
metric_name=metric_name,
Expand All @@ -1089,8 +1089,8 @@ def get_contour_plot(
# Some models don't implement '_predict', which is needed
# for the contour plots.
logger.info(
f"Model {self.generation_strategy.model} does not implement "
"`predict`, so it cannot be used to generate a response "
f"Model {self.standard_generation_strategy.model} does not "
"implement `predict`, so it cannot be used to generate a response "
"surface plot."
)
raise UnsupportedPlotError(
Expand All @@ -1112,14 +1112,14 @@ def get_feature_importances(self, relative: bool = True) -> AxPlotConfig:
"""
if not self.experiment.trials:
raise ValueError("Cannot generate plot as there are no trials.")
cur_model = self.generation_strategy.model
cur_model = self.standard_generation_strategy.model
if cur_model is not None:
try:
return plot_feature_importance_by_feature(cur_model, relative=relative)
except NotImplementedError:
logger.info(
f"Model {self.generation_strategy.model} does not implement "
"`feature_importances`, so it cannot be used to generate "
f"Model {self.standard_generation_strategy.model} does not "
"implement `feature_importances`, so it cannot be used to generate "
"this plot. Only certain models, implement feature importances."
)

Expand Down Expand Up @@ -1247,7 +1247,8 @@ def get_model_predictions(
else set(none_throws(self.experiment.metrics).keys())
)
model = none_throws(
self.generation_strategy.model, "No model has been instantiated yet."
self.standard_generation_strategy.model,
"No model has been instantiated yet.",
)

# Construct a dictionary that maps from a label to an
Expand Down Expand Up @@ -1306,8 +1307,8 @@ def fit_model(self) -> None:
"At least one trial must be completed with data to fit a model."
)
# Check if we should transition before generating the next candidate.
self.generation_strategy._maybe_transition_to_next_node()
self.generation_strategy._fit_current_model(data=None)
self.standard_generation_strategy._maybe_transition_to_next_node()
self.standard_generation_strategy._fit_current_model(data=None)

def verify_trial_parameterization(
self, trial_index: int, parameterization: TParameterization
Expand Down Expand Up @@ -1496,29 +1497,10 @@ def from_json_snapshot(

# ---------------------- Private helper methods. ---------------------

@property
def experiment(self) -> Experiment:
"""Returns the experiment set on this Ax client."""
return none_throws(
self._experiment,
(
"Experiment not set on Ax client. Must first "
"call load_experiment or create_experiment to use handler functions."
),
)

def get_trial(self, trial_index: int) -> Trial:
"""Return a trial on experiment cast as Trial"""
return checked_cast(Trial, self.experiment.trials[trial_index])

@property
def generation_strategy(self) -> GenerationStrategy:
"""Returns the generation strategy, set on this experiment."""
return none_throws(
self._generation_strategy,
"No generation strategy has been set on this optimization yet.",
)

@property
def objective(self) -> Objective:
return none_throws(self.experiment.optimization_config).objective
Expand Down Expand Up @@ -1586,7 +1568,7 @@ def get_best_trial(
) -> tuple[int, TParameterization, TModelPredictArm | None] | None:
return self._get_best_trial(
experiment=self.experiment,
generation_strategy=self.generation_strategy,
generation_strategy=self.standard_generation_strategy,
trial_indices=trial_indices,
use_model_predictions=use_model_predictions,
)
Expand All @@ -1600,7 +1582,7 @@ def get_pareto_optimal_parameters(
) -> dict[int, tuple[TParameterization, TModelPredictArm]]:
return self._get_pareto_optimal_parameters(
experiment=self.experiment,
generation_strategy=self.generation_strategy,
generation_strategy=self.standard_generation_strategy,
trial_indices=trial_indices,
use_model_predictions=use_model_predictions,
)
Expand All @@ -1614,7 +1596,7 @@ def get_hypervolume(
) -> float:
return BestPointMixin._get_hypervolume(
experiment=self.experiment,
generation_strategy=self.generation_strategy,
generation_strategy=self.standard_generation_strategy,
optimization_config=optimization_config,
trial_indices=trial_indices,
use_model_predictions=use_model_predictions,
Expand Down Expand Up @@ -1817,7 +1799,7 @@ def _gen_new_generator_run(
else None
)
with with_rng_seed(seed=self._random_seed):
return none_throws(self.generation_strategy).gen(
return none_throws(self.standard_generation_strategy).gen(
experiment=self.experiment,
n=n,
pending_observations=self._get_pending_observation_features(
Expand Down
17 changes: 0 additions & 17 deletions ax/service/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,6 @@ class Scheduler(AnalysisBase, BestPointMixin):
been saved, as otherwise experiment state could get corrupted.**
"""

experiment: Experiment
generation_strategy: GenerationStrategyInterface
# pyre-fixme[24]: Generic type `LoggerAdapter` expects 1 type parameter.
logger: LoggerAdapter
# Mapping of form {short string identifier -> message to show in reported
Expand Down Expand Up @@ -491,21 +489,6 @@ def runner(self) -> Runner:
)
return runner

@property
def standard_generation_strategy(self) -> GenerationStrategy:
"""Used for operations in the scheduler that can only be done with
and instance of ``GenerationStrategy``.
"""
gs = self.generation_strategy
if not isinstance(gs, GenerationStrategy):
raise NotImplementedError(
"This functionality is only supported with instances of "
"`GenerationStrategy` (one that uses `GenerationStrategy` "
"class) and not yet with other types of "
"`GenerationStrategyInterface`."
)
return gs

def __repr__(self) -> str:
"""Short user-friendly string representation."""
if not hasattr(self, "experiment"):
Expand Down
45 changes: 25 additions & 20 deletions ax/service/tests/test_ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,10 @@ def test_default_generation_strategy_continuous(self, _a, _b, _c, _d) -> None:
"""
ax_client = get_branin_optimization()
self.assertEqual(
[s.model for s in none_throws(ax_client.generation_strategy)._steps],
[
s.model
for s in none_throws(ax_client.standard_generation_strategy)._steps
],
[Models.SOBOL, Models.BOTORCH_MODULAR],
)
with self.assertRaisesRegex(ValueError, ".* no trials"):
Expand Down Expand Up @@ -713,7 +716,7 @@ def test_default_generation_strategy_continuous_for_moo(
},
)
self.assertEqual(
[s.model for s in none_throws(ax_client.generation_strategy)._steps],
[s.model for s in ax_client.standard_generation_strategy._steps],
[Models.SOBOL, Models.BOTORCH_MODULAR],
)
with self.assertRaisesRegex(ValueError, ".* no trials"):
Expand Down Expand Up @@ -776,7 +779,7 @@ def test_create_experiment(self) -> None:
steps=[GenerationStep(model=Models.SOBOL, num_trials=30)]
)
)
with self.assertRaisesRegex(AssertionError, "Experiment not set on Ax client"):
with self.assertRaisesRegex(AssertionError, "Experiment not set on AxClient"):
ax_client.experiment
ax_client.create_experiment(
name="test_experiment",
Expand Down Expand Up @@ -1013,7 +1016,7 @@ def test_create_single_objective_experiment_with_objectives_dict(self) -> None:
steps=[GenerationStep(model=Models.SOBOL, num_trials=30)]
)
)
with self.assertRaisesRegex(AssertionError, "Experiment not set on Ax client"):
with self.assertRaisesRegex(AssertionError, "Experiment not set on AxClient"):
ax_client.experiment
ax_client.create_experiment(
name="test_experiment",
Expand Down Expand Up @@ -1074,7 +1077,7 @@ def test_create_single_objective_experiment_with_objectives_dict(self) -> None:
def test_create_experiment_with_metric_definitions(self) -> None:
"""Test basic experiment creation."""
ax_client = AxClient()
with self.assertRaisesRegex(AssertionError, "Experiment not set on Ax client"):
with self.assertRaisesRegex(AssertionError, "Experiment not set on AxClient"):
ax_client.experiment

metric_definitions = {
Expand Down Expand Up @@ -1341,7 +1344,7 @@ def test_create_moo_experiment(self) -> None:
steps=[GenerationStep(model=Models.SOBOL, num_trials=30)]
)
)
with self.assertRaisesRegex(AssertionError, "Experiment not set on Ax client"):
with self.assertRaisesRegex(AssertionError, "Experiment not set on AxClient"):
ax_client.experiment
ax_client.create_experiment(
name="test_experiment",
Expand Down Expand Up @@ -1575,10 +1578,9 @@ def test_keep_generating_without_data(self) -> None:
{"name": "y", "type": "range", "bounds": [0.0, 15.0]},
],
)
self.assertFalse(
ax_client.generation_strategy._steps[0].enforce_num_trials, False
)
self.assertFalse(ax_client.generation_strategy._steps[1].max_parallelism, None)
gs = ax_client.standard_generation_strategy
self.assertFalse(gs._steps[0].enforce_num_trials, False)
self.assertFalse(gs._steps[1].max_parallelism, None)
for _ in range(10):
parameterization, trial_index = ax_client.get_next_trial()

Expand Down Expand Up @@ -2094,14 +2096,14 @@ def test_sqa_storage(self) -> None:
# pyre-fixme[6]: For 2nd param expected `Union[List[Tuple[Dict[str, U...
raw_data=branin(*parameters.values()),
)
gs = ax_client.generation_strategy
gs = ax_client.standard_generation_strategy
ax_client = AxClient(db_settings=db_settings)
ax_client.load_experiment_from_database("test_experiment")
# Some fields of the reloaded GS are not expected to be set (both will be
# set during next model fitting call), so we unset them on the original GS as
# well.
gs._unset_non_persistent_state_fields()
ax_client.generation_strategy._unset_non_persistent_state_fields()
ax_client.standard_generation_strategy._unset_non_persistent_state_fields()
self.assertEqual(gs, ax_client.generation_strategy)
with self.assertRaises(ValueError):
# Overwriting existing experiment.
Expand Down Expand Up @@ -2455,8 +2457,9 @@ def helper_test_get_pareto_optimal_points(
num_trials=20, outcome_constraints=outcome_constraints
)
ax_client.fit_model()
gs = ax_client.standard_generation_strategy
self.assertEqual(
ax_client.generation_strategy._curr.model_spec_to_gen_from.model_key,
gs._curr.model_spec_to_gen_from.model_key,
"BoTorch",
)

Expand All @@ -2481,7 +2484,7 @@ def helper_test_get_pareto_optimal_points(
# This overwrites the `predict` call to return the original observations,
# while testing the rest of the code as if we're using predictions.
# pyre-fixme[16]: `Optional` has no attribute `model`.
model = ax_client.generation_strategy.model.model
model = ax_client.standard_generation_strategy.model.model
ys = model.surrogate.training_data[0].Y
with patch.object(
model, "predict", return_value=(ys, torch.zeros(*ys.shape, ys.shape[-1]))
Expand Down Expand Up @@ -2525,8 +2528,9 @@ def helper_test_get_pareto_optimal_points_from_sobol_step(
ax_client, _ = get_branin_currin_optimization_with_N_sobol_trials(
num_trials=20, minimize=minimize, outcome_constraints=outcome_constraints
)
gs = ax_client.standard_generation_strategy
self.assertEqual(
ax_client.generation_strategy._curr.model_spec_to_gen_from.model_key,
gs._curr.model_spec_to_gen_from.model_key,
"Sobol",
)

Expand Down Expand Up @@ -2637,8 +2641,8 @@ def test_get_pareto_optimal_points_objective_threshold_inference(
ax_client, _ = get_branin_currin_optimization_with_N_sobol_trials(
num_trials=20, include_objective_thresholds=False
)
ax_client.generation_strategy._maybe_transition_to_next_node()
ax_client.generation_strategy._fit_current_model(
ax_client.standard_generation_strategy._maybe_transition_to_next_node()
ax_client.standard_generation_strategy._fit_current_model(
data=ax_client.experiment.lookup_data()
)

Expand Down Expand Up @@ -2849,7 +2853,8 @@ def test_with_hss(self) -> None:
# Make sure we actually tried a Botorch iteration and all the transforms it
# applies.
self.assertEqual(
ax_client.generation_strategy._generator_runs[-1]._model_key, "BoTorch"
ax_client.standard_generation_strategy._generator_runs[-1]._model_key,
"BoTorch",
)
self.assertEqual(len(ax_client.experiment.trials), 6)
ax_client.attach_trial(
Expand Down Expand Up @@ -2964,7 +2969,7 @@ def test_torch_device(self) -> None:
torch_device=device,
)
ax_client = get_branin_optimization(torch_device=device)
gpei_step_kwargs = ax_client.generation_strategy._steps[1].model_kwargs
gpei_step_kwargs = ax_client.standard_generation_strategy._steps[1].model_kwargs
self.assertEqual(gpei_step_kwargs["torch_device"], device)

def test_repr_function(
Expand Down Expand Up @@ -2993,7 +2998,7 @@ def test_gen_fixed_features(self) -> None:
name="fixed_features",
)
with mock.patch.object(
GenerationStrategy, "gen", wraps=ax_client.generation_strategy.gen
GenerationStrategy, "gen", wraps=ax_client.standard_generation_strategy.gen
) as mock_gen:
with self.subTest("fixed_features is None"):
params, idx = ax_client.get_next_trial()
Expand Down
Loading

0 comments on commit 9bf3507

Please sign in to comment.