-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Brings all the solvers from the hdbo repo * Adds a test env for BAxUS inside tox and its GitHub action * Skips baxus in the usual tests * Updates readme * Adds tests for all Ax-related solvers, and their GitHub action * Updates the test action for Ax * Skips the ax tests in the base testenv * Updates poli import * Skips ax solvers by checking whether the module was installed * Adds docs for the Ax-related solvers * Adds testing for probabilistic reparametrization * Adds tests for Bounce * Updates the env path for bounce * Adds Bounce to the table in the readme
- Loading branch information
1 parent
f3319ce
commit da3e725
Showing
38 changed files
with
1,250 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
name: Ax (py3.10 in conda) | ||
|
||
on: [push] | ||
|
||
jobs: | ||
build-linux: | ||
runs-on: ubuntu-latest | ||
strategy: | ||
max-parallel: 5 | ||
|
||
steps: | ||
- uses: actions/checkout@v3 | ||
- name: Set up Python 3.9 | ||
uses: actions/setup-python@v3 | ||
with: | ||
python-version: '3.9' | ||
- name: Add conda to system path | ||
run: | | ||
# $CONDA is an environment variable pointing to the root of the miniconda directory | ||
echo $CONDA/bin >> $GITHUB_PATH | ||
- name: Install dependencies | ||
run: | | ||
python -m pip install tox | ||
- name: Test Ax-related solvers with tox | ||
run: | | ||
tox -e poli-ax-base-py39 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
name: BAxUS (py3.10 in conda) | ||
|
||
on: [push] | ||
|
||
jobs: | ||
build-linux: | ||
runs-on: ubuntu-latest | ||
strategy: | ||
max-parallel: 5 | ||
|
||
steps: | ||
- uses: actions/checkout@v3 | ||
- name: Set up Python 3.9 | ||
uses: actions/setup-python@v3 | ||
with: | ||
python-version: '3.9' | ||
- name: Add conda to system path | ||
run: | | ||
# $CONDA is an environment variable pointing to the root of the miniconda directory | ||
echo $CONDA/bin >> $GITHUB_PATH | ||
- name: Install dependencies | ||
run: | | ||
python -m pip install tox | ||
- name: Test BAxUS with tox | ||
run: | | ||
tox -e poli-baxus-base-py39 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
name: Bounce (py3.10 in conda) | ||
|
||
on: [push] | ||
|
||
jobs: | ||
build-linux: | ||
runs-on: ubuntu-latest | ||
strategy: | ||
max-parallel: 5 | ||
|
||
steps: | ||
- uses: actions/checkout@v3 | ||
- name: Set up Python 3.9 | ||
uses: actions/setup-python@v3 | ||
with: | ||
python-version: '3.9' | ||
- name: Add conda to system path | ||
run: | | ||
# $CONDA is an environment variable pointing to the root of the miniconda directory | ||
echo $CONDA/bin >> $GITHUB_PATH | ||
- name: Install dependencies | ||
run: | | ||
python -m pip install tox | ||
- name: Test bounce with tox | ||
run: | | ||
tox -e poli-bounce-base-py39 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
name: Prob. Rep. (py3.10 in conda) | ||
|
||
on: [push] | ||
|
||
jobs: | ||
build-linux: | ||
runs-on: ubuntu-latest | ||
strategy: | ||
max-parallel: 5 | ||
|
||
steps: | ||
- uses: actions/checkout@v3 | ||
- name: Set up Python 3.9 | ||
uses: actions/setup-python@v3 | ||
with: | ||
python-version: '3.9' | ||
- name: Add conda to system path | ||
run: | | ||
# $CONDA is an environment variable pointing to the root of the miniconda directory | ||
echo $CONDA/bin >> $GITHUB_PATH | ||
- name: Install dependencies | ||
run: | | ||
python -m pip install tox | ||
- name: Test PR with tox | ||
run: | | ||
tox -e poli-pr-base-py39 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
name: Test (conda, python 3.9) | ||
name: Base (python 3.9 in conda) | ||
|
||
on: [push] | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Tuple | ||
import uuid | ||
|
||
from numpy import ndarray | ||
import numpy as np | ||
|
||
from poli.objective_repository import ToyContinuousBlackBox | ||
from poli.core.abstract_black_box import AbstractBlackBox | ||
|
||
from poli_baselines.core.abstract_solver import AbstractSolver | ||
|
||
from poli_baselines.core.utils.ax.interface import ( | ||
define_search_space, | ||
) | ||
|
||
from ax.service.ax_client import AxClient, ObjectiveProperties | ||
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy | ||
|
||
|
||
class AxSolver(AbstractSolver): | ||
def __init__( | ||
self, | ||
black_box: AbstractBlackBox, | ||
x0: ndarray, | ||
y0: ndarray, | ||
generation_strategy: GenerationStrategy, | ||
bounds: list[tuple[float, float]] | None = None, | ||
noise_std: float = 0.0, | ||
): | ||
super().__init__(black_box, x0, y0) | ||
self.noise_std = noise_std | ||
|
||
if bounds is None: | ||
assert isinstance(black_box, ToyContinuousBlackBox) | ||
bounds_ = [black_box.function.limits] * x0.shape[1] | ||
else: | ||
# If bounds is (lb, up), then we build the bounds | ||
# for the user | ||
if len(bounds) == 2: | ||
assert isinstance(bounds[0], (int, float)) | ||
assert isinstance(bounds[1], (int, float)) | ||
bounds_ = [bounds] * x0.shape[1] | ||
else: | ||
bounds_ = bounds | ||
|
||
assert len(bounds) == x0.shape[1] | ||
assert all(len(bound) == 2 for bound in bounds) | ||
|
||
ax_client = AxClient(generation_strategy=generation_strategy) | ||
exp_id = f"{uuid.uuid4()}"[:8] | ||
|
||
search_space = define_search_space(x0=x0, bounds=bounds_) | ||
|
||
ax_client.create_experiment( | ||
name=f"experiment_on_{black_box.info.name}_{exp_id}", | ||
parameters=[ | ||
{ | ||
"name": param.name, | ||
"type": "range", | ||
"bounds": [param.lower, param.upper], | ||
"value_type": "float", | ||
} | ||
for param in search_space.parameters.values() | ||
], | ||
objectives={black_box.info.name: ObjectiveProperties(minimize=False)}, | ||
) | ||
|
||
def evaluate( | ||
parametrization: dict[str, float] | ||
) -> dict[str, tuple[float, float]]: | ||
x = np.array([[parametrization[f"x{i}"] for i in range(x0.shape[1])]]) | ||
y = black_box(x) | ||
return {black_box.info.name: (y.flatten()[0], self.noise_std)} | ||
|
||
self.evaluate = evaluate | ||
|
||
# Run initialization with x0 and y0 | ||
for x, y in zip(x0, y0): | ||
params = {f"x{i}": float(x_i) for i, x_i in enumerate(x)} | ||
_, trial_index = ax_client.attach_trial(params) | ||
ax_client.complete_trial( | ||
trial_index=trial_index, | ||
raw_data={black_box.info.name: (y[0], self.noise_std)}, | ||
) | ||
|
||
print(ax_client.get_trials_data_frame()) | ||
self.ax_client = ax_client | ||
|
||
def solve( | ||
self, | ||
max_iter: int = 100, | ||
verbose: bool = False, | ||
) -> Tuple[np.ndarray, np.ndarray]: | ||
for i in range(max_iter): | ||
parameters, trial_index = self.ax_client.get_next_trial() | ||
val = self.evaluate(parameters) | ||
self.ax_client.complete_trial( | ||
trial_index=trial_index, | ||
raw_data=val, | ||
) | ||
# df = self.ax_client.get_trials_data_frame() | ||
|
||
if verbose: | ||
print( | ||
f"Iteration: {i}, Value in iteration: {val[self.black_box.info.name][0]:.3f}, Best so far: {self.ax_client.get_trials_data_frame()[self.black_box.info.name].max():.3f}" | ||
) | ||
|
||
# TODO: fix this return | ||
return self.ax_client.get_trials_data_frame() # type: ignore |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
name: poli__ax | ||
channels: | ||
- conda-forge | ||
- defaults | ||
dependencies: | ||
- python=3.10 | ||
- pip | ||
- pip: | ||
- scikit-learn | ||
- botorch | ||
- ax-platform | ||
- numpy | ||
- "git+https://github.com/MachineLearningLifeScience/poli.git@dev" | ||
- "git+https://github.com/MachineLearningLifeScience/poli-baselines@main" |
Oops, something went wrong.