From 0d25179f190ac4784d673e2ed8dcdb6d4a2c66f9 Mon Sep 17 00:00:00 2001 From: Alexander Cai Date: Sat, 29 Jun 2024 12:59:10 +0200 Subject: [PATCH] refactor solutions --- .gitignore | 5 ++++- Makefile | 8 +++++--- book/bandits.md | 54 +++++++++---------------------------------------- 3 files changed, 18 insertions(+), 49 deletions(-) diff --git a/.gitignore b/.gitignore index d8c7662..d0cfe25 100644 --- a/.gitignore +++ b/.gitignore @@ -2,13 +2,16 @@ _build/ .DS_Store .ipynb_checkpoints/ +# exercise solutions +solutions/ + # autogenerated Jupyter Book config for debugging conf.py # use MyST markdown as authoritative source *.ipynb -# latex +# latex build outputs *.aux *.bbl *.bcf diff --git a/Makefile b/Makefile index 5024409..3b20ae7 100644 --- a/Makefile +++ b/Makefile @@ -16,17 +16,19 @@ _META = \ META = $(addsuffix .md, $(addprefix book/, $(_META))) -CHAPTERS = $(NOTEBOOKS) $(META) +SOLUTIONS = book/solutions/bandits.py + +SOURCE = $(NOTEBOOKS) $(META) $(SOLUTIONS) CONFIG = book/_config.yml book/_toc.yml -book/_build/html: $(CHAPTERS) $(CONFIG) +book/_build/html: $(SOURCE) $(CONFIG) $(RUN) jb build book open: book/_build/html open book/_build/html/index.html -book/_build/latex: $(CHAPTERS) $(CONFIG) +book/_build/latex: $(SOURCE) $(CONFIG) $(RUN) jb build book --builder latex pdf: book/_build/latex diff --git a/book/bandits.md b/book/bandits.md index 8f59a6c..56c875c 100644 --- a/book/bandits.md +++ b/book/bandits.md @@ -26,6 +26,8 @@ from abc import ABC, abstractmethod # "Abstract Base Class" from typing import Callable, Union import matplotlib.pyplot as plt +import solutions.bandits as solutions + np.random.seed(184) # output_notebook() # set up bokeh @@ -38,14 +40,6 @@ def random_argmax(ary: Array) -> int: return np.random.choice(max_idx).item() -def choose_zero(ary: Float[Array, " K"]) -> Union[int, None]: - min_idx = np.flatnonzero(ary == 0) - if min_idx.size > 0: - return np.random.choice(min_idx).item() - else: - return None - - latex = latexify.algorithmic( prefixes={"mab"}, identifiers={"arm": "a_t", "reward": "r", "means": "mu"}, @@ -205,6 +199,8 @@ The rest of the chapter comprises a series of increasingly sophisticated MAB algorithms. ```{code-cell} +:tags: [hide-input] + def plot_strategy(mab: MAB, agent: Agent): plt.figure(figsize=(10, 6)) @@ -236,7 +232,7 @@ exploration"). class PureExploration(Agent): def choose_arm(self): """Choose an arm uniformly at random.""" - return np.random.randint(self.K) + return solutions.pure_exploration_choose_arm(self) ``` Note that @@ -274,15 +270,7 @@ call this the **pure greedy** strategy. class PureGreedy(Agent): def choose_arm(self): """Choose the arm with the highest observed reward on its first pull.""" - if self.count < self.K: - # first K steps: choose each arm once - return self.count - - if self.count == self.K: - # after the first K steps, choose the arm with the highest observed reward - self.greedy_arm = random_argmax(self.history[:, 1]) - - return self.greedy_arm + return solutions.pure_greedy_choose_arm(self) ``` Note we’ve used superscripts $r^k$ during the exploration phase to @@ -332,15 +320,7 @@ class ExploreThenCommit(Agent): self.N_explore = N_explore def choose_arm(self): - if self.count < self.K * self.N_explore: - # exploration phase: choose each arm N_explore times - return self.count // self.N_explore - - # exploitation phase: choose the arm with the highest observed reward - if self.count == self.K * self.N_explore: - self.greedy_arm = random_argmax(self.history[:, 1]) - - return self.greedy_arm + return solutions.etc_choose_arm(self) ``` ```{code-cell} @@ -493,16 +473,7 @@ class EpsilonGreedy(Agent): self.get_epsilon = get_epsilon def choose_arm(self): - epsilon = self.get_epsilon(self.count) - if np.random.random() < epsilon: - return np.random.randint(0, self.K - 1) - else: - counts = self.history.sum(axis=1) - unvisited = choose_zero(counts) - if unvisited is not None: - return unvisited - sample_means = self.history[:, 1] / counts - return random_argmax(sample_means) + return solutions.epsilon_greedy_choose_arm(self) ``` ```{code-cell} @@ -600,14 +571,7 @@ class UCB(Agent): self.delta = delta def choose_arm(self): - counts = self.history.sum(axis=1) - unvisited = choose_zero(counts) - if unvisited is not None: - return unvisited - sample_means = self.history[:, 1] / counts - bounds = np.sqrt(np.log(2 * self.count / self.delta) / (2 * counts)) - ucbs = sample_means + bounds - return random_argmax(ucbs) + return solutions.ucb_choose_arm(self) ``` Intuitively, UCB prioritizes arms where: