From 33c6ae7ba6d90b6cb2629c4f2f3d96b8195560c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miguel=20Gonz=C3=A1lez=20Duque?= Date: Thu, 7 Mar 2024 21:23:25 +0100 Subject: [PATCH] Fixes a couple of bugs on negative and multiobjective black boxes (#164) --- src/poli/core/abstract_black_box.py | 7 +++- src/poli/core/multi_objective_black_box.py | 25 ++++++++++---- .../test_multi_objective_and_negative.py | 33 +++++++++++++++++++ 3 files changed, 58 insertions(+), 7 deletions(-) create mode 100644 src/poli/tests/registry/test_multi_objective_and_negative.py diff --git a/src/poli/core/abstract_black_box.py b/src/poli/core/abstract_black_box.py index 342677a1..b326f137 100644 --- a/src/poli/core/abstract_black_box.py +++ b/src/poli/core/abstract_black_box.py @@ -362,7 +362,12 @@ class NegativeBlackBox(AbstractBlackBox): def __init__(self, f: AbstractBlackBox): self.f = f - super().__init__(info=f.info, batch_size=f.batch_size) + super().__init__( + batch_size=f.batch_size, + parallelize=f.parallelize, + num_workers=f.num_workers, + evaluation_budget=f.evaluation_budget, + ) def __call__(self, x, context=None): return -self.f.__call__(x, context) diff --git a/src/poli/core/multi_objective_black_box.py b/src/poli/core/multi_objective_black_box.py index ba379d86..bf6df1ed 100644 --- a/src/poli/core/multi_objective_black_box.py +++ b/src/poli/core/multi_objective_black_box.py @@ -10,6 +10,7 @@ from poli.core.abstract_black_box import AbstractBlackBox from poli.core.problem_setup_information import ProblemSetupInformation +from poli.core.black_box_information import BlackBoxInformation class MultiObjectiveBlackBox(AbstractBlackBox): @@ -20,8 +21,6 @@ class MultiObjectiveBlackBox(AbstractBlackBox): Parameters ----------- - info : ProblemSetupInformation - The problem setup information. batch_size : int, optional The batch size for evaluating the black box function. Defaults to None. objective_functions : List[AbstractBlackBox], required @@ -51,7 +50,6 @@ class MultiObjectiveBlackBox(AbstractBlackBox): def __init__( self, - info: ProblemSetupInformation, objective_functions: List[AbstractBlackBox], batch_size: int = None, ) -> None: @@ -60,8 +58,6 @@ def __init__( Parameters ----------- - info : ProblemSetupInformation - The problem setup information. objective_functions : List[AbstractBlackBox] The list of objective functions. batch_size : int, optional @@ -77,7 +73,9 @@ def __init__( "objective_functions must be provided as a list of AbstractBlackBox instances or inherited classes." ) - super().__init__(info=info, batch_size=batch_size) + super().__init__( + batch_size=batch_size, + ) self.objective_functions = objective_functions @@ -108,3 +106,18 @@ def __str__(self) -> str: def __repr__(self) -> str: return f"" + + @property + def info(self) -> BlackBoxInformation: + """ + Return the problem setup information for the multi-objective black box. + + Returns + ------- + BlackBoxInformation: + Information about the first objective function. + """ + # TODO: what should this return, actually? + # I'd say that we expect all objective functions to be able + # to take the same input. + return self.objective_functions[0].info diff --git a/src/poli/tests/registry/test_multi_objective_and_negative.py b/src/poli/tests/registry/test_multi_objective_and_negative.py new file mode 100644 index 00000000..6dde8c3c --- /dev/null +++ b/src/poli/tests/registry/test_multi_objective_and_negative.py @@ -0,0 +1,33 @@ +import numpy as np + + +def test_multi_objective_instantiation(): + from poli.objective_repository import AlohaBlackBox + from poli.core.multi_objective_black_box import MultiObjectiveBlackBox + + f_aloha = AlohaBlackBox() + + f = MultiObjectiveBlackBox( + objective_functions=[f_aloha, f_aloha], + ) + + assert f.objective_functions == [f_aloha, f_aloha] + + x0 = np.array([["A", "B", "C", "D", "E"]]) + y0 = f(x0) + + assert y0.shape == (1, 2) + + +def test_negative_black_boxes(): + from poli.objective_repository import AlohaBlackBox + + f = AlohaBlackBox() + g = -f + + x0 = np.array([["A", "B", "C", "D", "E"]]) + + f0 = f(x0) + g0 = g(x0) + + assert f0 == -g0