diff --git a/src/poli/core/util/abstract_observer.py b/src/poli/core/util/abstract_observer.py index ad021dea..fc322a1d 100644 --- a/src/poli/core/util/abstract_observer.py +++ b/src/poli/core/util/abstract_observer.py @@ -4,9 +4,6 @@ from poli.core.problem_setup_information import ProblemSetupInformation -import numpy as np - - class AbstractObserver: """ Abstract base class for observers in the poli library. diff --git a/src/poli/core/util/multi_observer.py b/src/poli/core/util/multi_observer.py new file mode 100644 index 00000000..a513e763 --- /dev/null +++ b/src/poli/core/util/multi_observer.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +import numpy as np + +from poli.core.black_box_information import BlackBoxInformation +from poli.core.util.abstract_observer import AbstractObserver + + +class MultiObserver(AbstractObserver): + def __init__(self, observers: list[AbstractObserver]): + self.observers = observers + + def observe(self, x: np.ndarray, y: np.ndarray, context=None) -> None: + for observer in self.observers: + observer.observe(x, y, context) + + def initialize_observer( + self, + problem_setup_info: BlackBoxInformation, + caller_info: object, + x0: np.ndarray, + y0: np.ndarray, + seed: int, + ) -> object: + for observer in self.observers: + observer.initialize_observer(problem_setup_info, caller_info, x0, y0, seed) diff --git a/src/poli/tests/observers/test_observers.py b/src/poli/tests/observers/test_observers.py index f9895222..73a188d3 100644 --- a/src/poli/tests/observers/test_observers.py +++ b/src/poli/tests/observers/test_observers.py @@ -16,6 +16,7 @@ from poli.core.problem_setup_information import ProblemSetupInformation from poli.core.util.abstract_observer import AbstractObserver +from poli.core.util.multi_observer import MultiObserver from poli import objective_factory @@ -178,5 +179,26 @@ def test_multiple_observer_registration(): delete_observer_run_script(observer_name="simple_2__") +def test_multi_observer_works(): + obs_1 = SimpleObserver(experiment_id="example_1") + obs_2 = SimpleObserver(experiment_id="example_2") + + multi_observer = MultiObserver(observers=[obs_1, obs_2]) + + # Creating a black box function + problem = objective_factory.create(name="aloha", observer=multi_observer) + + # Evaluating the black box function + problem.black_box(np.array([list("MIGUE")])) + + assert obs_1.results == [{"x": [["M", "I", "G", "U", "E"]], "y": [[0.0]]}] + assert obs_2.results == [{"x": [["M", "I", "G", "U", "E"]], "y": [[0.0]]}] + + # Cleaning up (and testing whether we can access attributes + # of the external observer) + (obs_1.experiment_path / "metadata.json").unlink() + (obs_2.experiment_path / "metadata.json").unlink() + + if __name__ == "__main__": test_multiple_observer_registration()