Skip to content

Commit

Permalink
Implements a multi-observer (#194)
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgondu authored May 22, 2024
1 parent 9f26499 commit 8ac4ca6
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 3 deletions.
3 changes: 0 additions & 3 deletions src/poli/core/util/abstract_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
from poli.core.problem_setup_information import ProblemSetupInformation


import numpy as np


class AbstractObserver:
"""
Abstract base class for observers in the poli library.
Expand Down
26 changes: 26 additions & 0 deletions src/poli/core/util/multi_observer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from __future__ import annotations

import numpy as np

from poli.core.black_box_information import BlackBoxInformation
from poli.core.util.abstract_observer import AbstractObserver


class MultiObserver(AbstractObserver):
def __init__(self, observers: list[AbstractObserver]):
self.observers = observers

def observe(self, x: np.ndarray, y: np.ndarray, context=None) -> None:
for observer in self.observers:
observer.observe(x, y, context)

def initialize_observer(
self,
problem_setup_info: BlackBoxInformation,
caller_info: object,
x0: np.ndarray,
y0: np.ndarray,
seed: int,
) -> object:
for observer in self.observers:
observer.initialize_observer(problem_setup_info, caller_info, x0, y0, seed)
22 changes: 22 additions & 0 deletions src/poli/tests/observers/test_observers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from poli.core.problem_setup_information import ProblemSetupInformation
from poli.core.util.abstract_observer import AbstractObserver
from poli.core.util.multi_observer import MultiObserver

from poli import objective_factory

Expand Down Expand Up @@ -178,5 +179,26 @@ def test_multiple_observer_registration():
delete_observer_run_script(observer_name="simple_2__")


def test_multi_observer_works():
obs_1 = SimpleObserver(experiment_id="example_1")
obs_2 = SimpleObserver(experiment_id="example_2")

multi_observer = MultiObserver(observers=[obs_1, obs_2])

# Creating a black box function
problem = objective_factory.create(name="aloha", observer=multi_observer)

# Evaluating the black box function
problem.black_box(np.array([list("MIGUE")]))

assert obs_1.results == [{"x": [["M", "I", "G", "U", "E"]], "y": [[0.0]]}]
assert obs_2.results == [{"x": [["M", "I", "G", "U", "E"]], "y": [[0.0]]}]

# Cleaning up (and testing whether we can access attributes
# of the external observer)
(obs_1.experiment_path / "metadata.json").unlink()
(obs_2.experiment_path / "metadata.json").unlink()


if __name__ == "__main__":
test_multiple_observer_registration()

0 comments on commit 8ac4ca6

Please sign in to comment.