Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Combine experiment and gs fields into AnalysisBase #3137

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions ax/modelbridge/tests/test_prediction_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_predict_at_point(self) -> None:

observation_features = ObservationFeatures(parameters={"x1": 0.3, "x2": 0.5})
y_hat, se_hat = predict_at_point(
model=none_throws(ax_client.generation_strategy.model),
model=none_throws(ax_client.standard_generation_strategy.model),
obsf=observation_features,
metric_names={"test_metric1"},
)
Expand All @@ -37,7 +37,7 @@ def test_predict_at_point(self) -> None:
self.assertEqual(len(se_hat), 1)

y_hat, se_hat = predict_at_point(
model=none_throws(ax_client.generation_strategy.model),
model=none_throws(ax_client.standard_generation_strategy.model),
obsf=observation_features,
metric_names={"test_metric1", "test_metric2", "test_metric:agg"},
scalarized_metric_config=[
Expand All @@ -51,7 +51,7 @@ def test_predict_at_point(self) -> None:
self.assertEqual(len(se_hat), 3)

y_hat, se_hat = predict_at_point(
model=none_throws(ax_client.generation_strategy.model),
model=none_throws(ax_client.standard_generation_strategy.model),
obsf=observation_features,
metric_names={"test_metric1"},
scalarized_metric_config=[
Expand All @@ -75,7 +75,7 @@ def test_predict_by_features(self) -> None:
20: ObservationFeatures(parameters={"x1": 0.8, "x2": 0.5}),
}
predictions_map = predict_by_features(
model=none_throws(ax_client.generation_strategy.model),
model=none_throws(ax_client.standard_generation_strategy.model),
label_to_feature_dict=observation_features_dict,
metric_names={"test_metric1"},
)
Expand Down
61 changes: 22 additions & 39 deletions ax/service/ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,15 @@
from ax.plot.feature_importances import plot_feature_importance_by_feature
from ax.plot.helper import _format_dict
from ax.plot.trace import optimization_trace_single_method
from ax.service.utils.analysis_base import AnalysisBase
from ax.service.utils.best_point_mixin import BestPointMixin
from ax.service.utils.instantiation import (
FixedFeatures,
InstantiationBase,
ObjectiveProperties,
)
from ax.service.utils.report_utils import exp_to_df
from ax.service.utils.with_db_settings_base import DBSettings, WithDBSettingsBase
from ax.service.utils.with_db_settings_base import DBSettings
from ax.storage.json_store.decoder import (
generation_strategy_from_json,
object_from_json,
Expand Down Expand Up @@ -108,7 +109,7 @@
)


class AxClient(WithDBSettingsBase, BestPointMixin, InstantiationBase):
class AxClient(AnalysisBase, BestPointMixin, InstantiationBase):
"""
Convenience handler for management of experimentation cycle through a
service-like API. External system manages scheduling of the cycle and makes
Expand Down Expand Up @@ -598,8 +599,8 @@ def get_next_trial(
# TODO[T79183560]: Ensure correct handling of generator run when using
# foreign keys.
self._update_generation_strategy_in_db_if_possible(
generation_strategy=self.generation_strategy,
new_generator_runs=[self.generation_strategy._generator_runs[-1]],
generation_strategy=self.standard_generation_strategy,
new_generator_runs=[self.standard_generation_strategy._generator_runs[-1]],
)
return none_throws(trial.arm).parameters, trial.index

Expand All @@ -624,7 +625,7 @@ def get_current_trial_generation_limit(self) -> tuple[int, bool]:
if self.generation_strategy._experiment is None:
self.generation_strategy.experiment = self.experiment

return self.generation_strategy.current_generator_run_limit()
return self.standard_generation_strategy.current_generator_run_limit()

def get_next_trials(
self,
Expand Down Expand Up @@ -949,7 +950,7 @@ def get_max_parallelism(self) -> list[tuple[int, int]]:
Mapping of form {num_trials -> max_parallelism_setting}.
"""
parallelism_settings = []
for step in self.generation_strategy._steps:
for step in self.standard_generation_strategy._steps:
parallelism_settings.append(
(step.num_trials, step.max_parallelism or step.num_trials)
)
Expand Down Expand Up @@ -1070,15 +1071,15 @@ def get_contour_plot(
raise ValueError(
f'Metric "{metric_name}" is not associated with this optimization.'
)
if self.generation_strategy.model is not None:
if self.standard_generation_strategy.model is not None:
try:
logger.info(
f"Retrieving contour plot with parameter '{param_x}' on X-axis "
f"and '{param_y}' on Y-axis, for metric '{metric_name}'. "
"Remaining parameters are affixed to the middle of their range."
)
return plot_contour(
model=none_throws(self.generation_strategy.model),
model=none_throws(self.standard_generation_strategy.model),
param_x=param_x,
param_y=param_y,
metric_name=metric_name,
Expand All @@ -1088,8 +1089,8 @@ def get_contour_plot(
# Some models don't implement '_predict', which is needed
# for the contour plots.
logger.info(
f"Model {self.generation_strategy.model} does not implement "
"`predict`, so it cannot be used to generate a response "
f"Model {self.standard_generation_strategy.model} does not "
"implement `predict`, so it cannot be used to generate a response "
"surface plot."
)
raise UnsupportedPlotError(
Expand All @@ -1111,14 +1112,14 @@ def get_feature_importances(self, relative: bool = True) -> AxPlotConfig:
"""
if not self.experiment.trials:
raise ValueError("Cannot generate plot as there are no trials.")
cur_model = self.generation_strategy.model
cur_model = self.standard_generation_strategy.model
if cur_model is not None:
try:
return plot_feature_importance_by_feature(cur_model, relative=relative)
except NotImplementedError:
logger.info(
f"Model {self.generation_strategy.model} does not implement "
"`feature_importances`, so it cannot be used to generate "
f"Model {self.standard_generation_strategy.model} does not "
"implement `feature_importances`, so it cannot be used to generate "
"this plot. Only certain models, implement feature importances."
)

Expand Down Expand Up @@ -1246,7 +1247,8 @@ def get_model_predictions(
else set(none_throws(self.experiment.metrics).keys())
)
model = none_throws(
self.generation_strategy.model, "No model has been instantiated yet."
self.standard_generation_strategy.model,
"No model has been instantiated yet.",
)

# Construct a dictionary that maps from a label to an
Expand Down Expand Up @@ -1305,8 +1307,8 @@ def fit_model(self) -> None:
"At least one trial must be completed with data to fit a model."
)
# Check if we should transition before generating the next candidate.
self.generation_strategy._maybe_transition_to_next_node()
self.generation_strategy._fit_current_model(data=None)
self.standard_generation_strategy._maybe_transition_to_next_node()
self.standard_generation_strategy._fit_current_model(data=None)

def verify_trial_parameterization(
self, trial_index: int, parameterization: TParameterization
Expand Down Expand Up @@ -1495,29 +1497,10 @@ def from_json_snapshot(

# ---------------------- Private helper methods. ---------------------

@property
def experiment(self) -> Experiment:
"""Returns the experiment set on this Ax client."""
return none_throws(
self._experiment,
(
"Experiment not set on Ax client. Must first "
"call load_experiment or create_experiment to use handler functions."
),
)

def get_trial(self, trial_index: int) -> Trial:
"""Return a trial on experiment cast as Trial"""
return checked_cast(Trial, self.experiment.trials[trial_index])

@property
def generation_strategy(self) -> GenerationStrategy:
"""Returns the generation strategy, set on this experiment."""
return none_throws(
self._generation_strategy,
"No generation strategy has been set on this optimization yet.",
)

@property
def objective(self) -> Objective:
return none_throws(self.experiment.optimization_config).objective
Expand Down Expand Up @@ -1585,7 +1568,7 @@ def get_best_trial(
) -> tuple[int, TParameterization, TModelPredictArm | None] | None:
return self._get_best_trial(
experiment=self.experiment,
generation_strategy=self.generation_strategy,
generation_strategy=self.standard_generation_strategy,
trial_indices=trial_indices,
use_model_predictions=use_model_predictions,
)
Expand All @@ -1599,7 +1582,7 @@ def get_pareto_optimal_parameters(
) -> dict[int, tuple[TParameterization, TModelPredictArm]]:
return self._get_pareto_optimal_parameters(
experiment=self.experiment,
generation_strategy=self.generation_strategy,
generation_strategy=self.standard_generation_strategy,
trial_indices=trial_indices,
use_model_predictions=use_model_predictions,
)
Expand All @@ -1613,7 +1596,7 @@ def get_hypervolume(
) -> float:
return BestPointMixin._get_hypervolume(
experiment=self.experiment,
generation_strategy=self.generation_strategy,
generation_strategy=self.standard_generation_strategy,
optimization_config=optimization_config,
trial_indices=trial_indices,
use_model_predictions=use_model_predictions,
Expand Down Expand Up @@ -1816,7 +1799,7 @@ def _gen_new_generator_run(
else None
)
with with_rng_seed(seed=self._random_seed):
return none_throws(self.generation_strategy).gen(
return none_throws(self.standard_generation_strategy).gen(
experiment=self.experiment,
n=n,
pending_observations=self._get_pending_observation_features(
Expand Down
91 changes: 2 additions & 89 deletions ax/service/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@

from __future__ import annotations

import traceback

from collections.abc import Callable, Generator, Iterable, Mapping
from copy import deepcopy
from dataclasses import dataclass
Expand All @@ -20,10 +18,6 @@
from typing import Any, cast, NamedTuple, Optional

import ax.service.utils.early_stopping as early_stopping_utils
import pandas as pd
from ax.analysis.analysis import Analysis, AnalysisCard, AnalysisCardLevel, AnalysisE
from ax.analysis.markdown.markdown_analysis import MarkdownAnalysisCard
from ax.analysis.plotly.parallel_coordinates import ParallelCoordinatesPlot
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
Expand Down Expand Up @@ -57,6 +51,7 @@
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.modelbridge.modelbridge_utils import get_fixed_features_from_experiment
from ax.service.utils.analysis_base import AnalysisBase
from ax.service.utils.best_point_mixin import BestPointMixin
from ax.service.utils.scheduler_options import SchedulerOptions, TrialType
from ax.service.utils.with_db_settings_base import DBSettings, WithDBSettingsBase
Expand All @@ -70,7 +65,6 @@
set_ax_logger_levels,
)
from ax.utils.common.timeutils import current_timestamp_in_millis
from ax.utils.common.typeutils import checked_cast
from pyre_extensions import assert_is_instance, none_throws


Expand Down Expand Up @@ -151,7 +145,7 @@ def append(self, text: str) -> None:
self.text += text


class Scheduler(WithDBSettingsBase, BestPointMixin):
class Scheduler(AnalysisBase, BestPointMixin):
"""Closed-loop manager class for Ax optimization.

Attributes:
Expand All @@ -168,8 +162,6 @@ class Scheduler(WithDBSettingsBase, BestPointMixin):
been saved, as otherwise experiment state could get corrupted.**
"""

experiment: Experiment
generation_strategy: GenerationStrategyInterface
# pyre-fixme[24]: Generic type `LoggerAdapter` expects 1 type parameter.
logger: LoggerAdapter
# Mapping of form {short string identifier -> message to show in reported
Expand Down Expand Up @@ -497,21 +489,6 @@ def runner(self) -> Runner:
)
return runner

@property
def standard_generation_strategy(self) -> GenerationStrategy:
"""Used for operations in the scheduler that can only be done with
and instance of ``GenerationStrategy``.
"""
gs = self.generation_strategy
if not isinstance(gs, GenerationStrategy):
raise NotImplementedError(
"This functionality is only supported with instances of "
"`GenerationStrategy` (one that uses `GenerationStrategy` "
"class) and not yet with other types of "
"`GenerationStrategyInterface`."
)
return gs

def __repr__(self) -> str:
"""Short user-friendly string representation."""
if not hasattr(self, "experiment"):
Expand Down Expand Up @@ -679,62 +656,6 @@ def run_all_trials(
idle_callback=idle_callback,
)

def compute_analyses(
self, analyses: Iterable[Analysis] | None = None
) -> list[AnalysisCard]:
"""
Compute Analyses for the Experiment and GenerationStrategy associated with this
Scheduler instance and save them to the DB if possible. If an Analysis fails to
compute (e.g. due to a missing metric), it will be skipped and a warning will
be logged.

Args:
analyses: Analyses to compute. If None, the Scheduler will choose a set of
Analyses to compute based on the Experiment and GenerationStrategy.
"""
analyses = analyses if analyses is not None else self._choose_analyses()

results = [
analysis.compute_result(
experiment=self.experiment, generation_strategy=self.generation_strategy
)
for analysis in analyses
]

# TODO Accumulate Es into their own card, perhaps via unwrap_or_else
cards = [result.unwrap() for result in results if result.is_ok()]

for result in results:
if result.is_err():
e = checked_cast(AnalysisE, result.err)
traceback_str = "".join(
traceback.format_exception(
type(result.err.exception),
e.exception,
e.exception.__traceback__,
)
)
cards.append(
MarkdownAnalysisCard(
name=e.analysis.name,
# It would be better if we could reliably compute the title
# without risking another error
title=f"{e.analysis.name} Error",
subtitle=f"An error occurred while computing {e.analysis}",
attributes=e.analysis.attributes,
blob=traceback_str,
df=pd.DataFrame(),
level=AnalysisCardLevel.DEBUG,
)
)

self._save_analysis_cards_to_db_if_possible(
analysis_cards=cards,
experiment=self.experiment,
)

return cards

def run_trials_and_yield_results(
self,
max_trials: int,
Expand Down Expand Up @@ -1882,14 +1803,6 @@ def _get_next_trials(
trials.append(trial)
return trials, None

def _choose_analyses(self) -> list[Analysis]:
"""
Choose Analyses to compute based on the Experiment, GenerationStrategy, etc.
"""

# TODO Create a useful heuristic for choosing analyses
return [ParallelCoordinatesPlot()]

def _gen_new_trials_from_generation_strategy(
self,
num_trials: int,
Expand Down
Loading
Loading