
Commit 10fa07c
Change MAB experimenters to use more general arms_to_rewards dictionary.

PiperOrigin-RevId: 697654482
xingyousong authored and copybara-github committed Nov 18, 2024
1 parent 3a53389 commit 10fa07c
Showing 2 changed files with 18 additions and 31 deletions.
@@ -75,7 +75,7 @@ def testNormalizationApply(self, func):
 
   def test_NormalizingCategoricals(self):
     mab_exptr = multiarm.FixedMultiArmExperimenter(
-        rewards=[-1e6, 0.0, 1e6], arms_as_chars=False
+        arms_to_rewards={'0': -1e6, '1': 0.0, '2': 1e6}
     )
     norm_exptr = normalizing_experimenter.NormalizingExperimenter(mab_exptr)
     metric_name = norm_exptr.problem_statement().metric_information.item().name
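Note: the updated test keys each reward by its stringified arm index, matching the old arms_as_chars=False behavior. For other callers migrating off the removed positional API, a minimal sketch (the rewards list is illustrative; the key rules come from the deleted helper code):

# Old: FixedMultiArmExperimenter(rewards=rewards, arms_as_chars=...)
rewards = [-1e6, 0.0, 1e6]
# arms_as_chars=False keyed arms as str(i); arms_as_chars=True as chr(i + 97), i.e. 'a', 'b', ...
arms_to_rewards = {str(i): r for i, r in enumerate(rewards)}
arms_to_rewards_chars = {chr(i + 97): r for i, r in enumerate(rewards)}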
47 changes: 17 additions & 30 deletions vizier/_src/benchmarks/experimenters/synthetic/multiarm.py
@@ -20,32 +20,20 @@
 distributions.
 """
 
-import copy
-from typing import Optional, Sequence
+from typing import Mapping, Optional, Sequence
 
 import numpy as np
 from vizier import pyvizier as vz
 from vizier._src.benchmarks.experimenters import experimenter
 
 
-def _default_multiarm_problem(
-    num_arms: int, arms_as_chars: bool
-) -> vz.ProblemStatement:
+def _default_multiarm_problem(arms: Sequence[str]) -> vz.ProblemStatement:
   """Returns default multi-arm problem statement."""
   problem = vz.ProblemStatement()
   problem.metric_information.append(
       vz.MetricInformation(name="reward", goal=vz.ObjectiveMetricGoal.MAXIMIZE)
   )
-
-  if arms_as_chars:
-    # Starts with 'a' character.
-    feasible_values = [chr(i + 97) for i in range(num_arms)]
-  else:
-    feasible_values = [str(i) for i in range(num_arms)]
-
-  problem.search_space.root.add_categorical_param(
-      name="arm", feasible_values=feasible_values
-  )
+  problem.search_space.root.add_categorical_param("arm", feasible_values=arms)
   return problem


@@ -54,40 +42,39 @@ class BernoulliMultiArmExperimenter(experimenter.Experimenter):
 
   def __init__(
       self,
-      probs: Sequence[float],
-      arms_as_chars: bool = True,
+      arms_to_probs: Mapping[str, float],
       seed: Optional[int] = None,
   ):
-    self._probs = probs
+    if sum(arms_to_probs.values()) != 1.0:
+      raise ValueError(
+          "Sum of probabilities must be 1, got %s" % sum(arms_to_probs.values())
+      )
+    self._arms_to_probs = arms_to_probs
     self._rng = np.random.RandomState(seed)
-    self._problem = _default_multiarm_problem(len(self._probs), arms_as_chars)
 
   def problem_statement(self) -> vz.ProblemStatement:
-    return copy.deepcopy(self._problem)
+    return _default_multiarm_problem(list(self._arms_to_probs.keys()))
 
   def evaluate(self, suggestions: Sequence[vz.Trial]) -> None:
     """Each arm has a fixed probability of outputting 0 or 1 reward."""
-    feasibles = self._problem.search_space.parameters[0].feasible_values
     for suggestion in suggestions:
-      arm_index = feasibles.index(suggestion.parameters["arm"].value)
-      prob = self._probs[arm_index]
+      arm = suggestion.parameters["arm"].value
+      prob = self._arms_to_probs[arm]
       reward = self._rng.choice([0, 1], p=[1 - prob, prob])
       suggestion.final_measurement = vz.Measurement(metrics={"reward": reward})
 
 
 class FixedMultiArmExperimenter(experimenter.Experimenter):
   """Rewards are deterministic."""
 
-  def __init__(self, rewards: Sequence[float], arms_as_chars: bool = True):
-    self._rewards = rewards
-    self._problem = _default_multiarm_problem(len(self._rewards), arms_as_chars)
+  def __init__(self, arms_to_rewards: Mapping[str, float]):
+    self._arms_to_rewards = arms_to_rewards
 
   def problem_statement(self) -> vz.ProblemStatement:
-    return copy.deepcopy(self._problem)
+    return _default_multiarm_problem(list(self._arms_to_rewards.keys()))
 
   def evaluate(self, suggestions: Sequence[vz.Trial]) -> None:
-    feasibles = self._problem.search_space.parameters[0].feasible_values
     for suggestion in suggestions:
-      arm_index = feasibles.index(suggestion.parameters["arm"].value)
-      reward = self._rewards[arm_index]
+      arm = suggestion.parameters["arm"].value
+      reward = self._arms_to_rewards[arm]
       suggestion.final_measurement = vz.Measurement(metrics={"reward": reward})
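A quick usage sketch of the mapping-based interface, assuming vz.Trial accepts a parameters mapping at construction and that measurement metrics expose a .value field (arm names and reward values are illustrative):

from vizier import pyvizier as vz
from vizier._src.benchmarks.experimenters.synthetic import multiarm

# Deterministic rewards, keyed directly by arm name.
fixed = multiarm.FixedMultiArmExperimenter(arms_to_rewards={'a': 0.1, 'b': 0.9})
problem = fixed.problem_statement()  # categorical "arm" param with feasible values ['a', 'b']

trial = vz.Trial(parameters={'arm': 'b'})
fixed.evaluate([trial])
print(trial.final_measurement.metrics['reward'].value)  # 0.9, deterministic

# Stochastic arms: note the new check requires the probabilities to sum to 1.0.
bern = multiarm.BernoulliMultiArmExperimenter(arms_to_probs={'a': 0.3, 'b': 0.7}, seed=0)
bern.evaluate([vz.Trial(parameters={'arm': 'a'})])  # reward is 0 or 1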
