refactor solutions
adzcai committed Jun 29, 2024
1 parent f3a1230 commit 0d25179
Showing 3 changed files with 18 additions and 49 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -2,13 +2,16 @@ _build/
 .DS_Store
 .ipynb_checkpoints/

+# exercise solutions
+solutions/
+
 # autogenerated Jupyter Book config for debugging
 conf.py

 # use MyST markdown as authoritative source
 *.ipynb

-# latex
+# latex build outputs
 *.aux
 *.bbl
 *.bcf
8 changes: 5 additions & 3 deletions Makefile
@@ -16,17 +16,19 @@ _META = \

 META = $(addsuffix .md, $(addprefix book/, $(_META)))

-CHAPTERS = $(NOTEBOOKS) $(META)
+SOLUTIONS = book/solutions/bandits.py
+
+SOURCE = $(NOTEBOOKS) $(META) $(SOLUTIONS)

 CONFIG = book/_config.yml book/_toc.yml

-book/_build/html: $(CHAPTERS) $(CONFIG)
+book/_build/html: $(SOURCE) $(CONFIG)
 	$(RUN) jb build book

 open: book/_build/html
 	open book/_build/html/index.html

-book/_build/latex: $(CHAPTERS) $(CONFIG)
+book/_build/latex: $(SOURCE) $(CONFIG)
 	$(RUN) jb build book --builder latex

 pdf: book/_build/latex
54 changes: 9 additions & 45 deletions book/bandits.md
@@ -26,6 +26,8 @@ from abc import ABC, abstractmethod # "Abstract Base Class"
 from typing import Callable, Union
 import matplotlib.pyplot as plt
+import solutions.bandits as solutions
 np.random.seed(184)
 # output_notebook() # set up bokeh
Expand All @@ -38,14 +40,6 @@ def random_argmax(ary: Array) -> int:
return np.random.choice(max_idx).item()
def choose_zero(ary: Float[Array, " K"]) -> Union[int, None]:
min_idx = np.flatnonzero(ary == 0)
if min_idx.size > 0:
return np.random.choice(min_idx).item()
else:
return None
latex = latexify.algorithmic(
prefixes={"mab"},
identifiers={"arm": "a_t", "reward": "r", "means": "mu"},
Expand Down Expand Up @@ -205,6 +199,8 @@ The rest of the chapter comprises a series of increasingly sophisticated
MAB algorithms.

```{code-cell}
:tags: [hide-input]
def plot_strategy(mab: MAB, agent: Agent):
plt.figure(figsize=(10, 6))
Expand Down Expand Up @@ -236,7 +232,7 @@ exploration").
class PureExploration(Agent):
def choose_arm(self):
"""Choose an arm uniformly at random."""
return np.random.randint(self.K)
return solutions.pure_exploration_choose_arm(self)
```

Note that
Expand Down Expand Up @@ -274,15 +270,7 @@ call this the **pure greedy** strategy.
class PureGreedy(Agent):
def choose_arm(self):
"""Choose the arm with the highest observed reward on its first pull."""
if self.count < self.K:
# first K steps: choose each arm once
return self.count
if self.count == self.K:
# after the first K steps, choose the arm with the highest observed reward
self.greedy_arm = random_argmax(self.history[:, 1])
return self.greedy_arm
return solutions.pure_greedy_choose_arm(self)
```

Note we’ve used superscripts $r^k$ during the exploration phase to
Expand Down Expand Up @@ -332,15 +320,7 @@ class ExploreThenCommit(Agent):
self.N_explore = N_explore
def choose_arm(self):
if self.count < self.K * self.N_explore:
# exploration phase: choose each arm N_explore times
return self.count // self.N_explore
# exploitation phase: choose the arm with the highest observed reward
if self.count == self.K * self.N_explore:
self.greedy_arm = random_argmax(self.history[:, 1])
return self.greedy_arm
return solutions.etc_choose_arm(self)
```

```{code-cell}
Expand Down Expand Up @@ -493,16 +473,7 @@ class EpsilonGreedy(Agent):
self.get_epsilon = get_epsilon
def choose_arm(self):
epsilon = self.get_epsilon(self.count)
if np.random.random() < epsilon:
return np.random.randint(0, self.K - 1)
else:
counts = self.history.sum(axis=1)
unvisited = choose_zero(counts)
if unvisited is not None:
return unvisited
sample_means = self.history[:, 1] / counts
return random_argmax(sample_means)
return solutions.epsilon_greedy_choose_arm(self)
```

```{code-cell}
Expand Down Expand Up @@ -600,14 +571,7 @@ class UCB(Agent):
self.delta = delta
def choose_arm(self):
counts = self.history.sum(axis=1)
unvisited = choose_zero(counts)
if unvisited is not None:
return unvisited
sample_means = self.history[:, 1] / counts
bounds = np.sqrt(np.log(2 * self.count / self.delta) / (2 * counts))
ucbs = sample_means + bounds
return random_argmax(ucbs)
return solutions.ucb_choose_arm(self)
```

Intuitively, UCB prioritizes arms where:
Expand Down

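The solution bodies themselves are not part of this commit: only three files changed, and `solutions/` is now git-ignored. As a rough sketch only, assuming each `solutions.*_choose_arm` function takes the agent as its single argument (as the new call sites suggest) and simply hosts the code removed from `book/bandits.md`, `book/solutions/bandits.py` might look something like the following. The first line of `random_argmax` is not visible in the diff and is an assumption, as is the `SimpleNamespace` stand-in agent used for the demo. The exploration branch of epsilon-greedy deliberately uses `np.random.randint(agent.K)` instead of the removed `np.random.randint(0, self.K - 1)`, because NumPy's upper bound is exclusive and the removed call could never pick the last arm.

```python
from types import SimpleNamespace

import numpy as np


def random_argmax(ary):
    """Index of a maximal entry, ties broken uniformly at random."""
    max_idx = np.flatnonzero(ary == ary.max())  # assumed; elided in the diff
    return np.random.choice(max_idx).item()


def choose_zero(ary):
    """Index of a randomly chosen zero entry, or None if there is none."""
    zero_idx = np.flatnonzero(ary == 0)
    if zero_idx.size > 0:
        return np.random.choice(zero_idx).item()
    return None


def pure_exploration_choose_arm(agent):
    # Choose an arm uniformly at random.
    return np.random.randint(agent.K)


def pure_greedy_choose_arm(agent):
    if agent.count < agent.K:
        # First K steps: choose each arm once.
        return agent.count
    if agent.count == agent.K:
        # Commit to the arm with the highest observed reward.
        agent.greedy_arm = random_argmax(agent.history[:, 1])
    return agent.greedy_arm


def etc_choose_arm(agent):
    if agent.count < agent.K * agent.N_explore:
        # Exploration phase: choose each arm N_explore times.
        return agent.count // agent.N_explore
    if agent.count == agent.K * agent.N_explore:
        # Exploitation phase begins: commit to the best arm so far.
        agent.greedy_arm = random_argmax(agent.history[:, 1])
    return agent.greedy_arm


def epsilon_greedy_choose_arm(agent):
    epsilon = agent.get_epsilon(agent.count)
    if np.random.random() < epsilon:
        # Differs from the removed line: samples from all K arms.
        return np.random.randint(agent.K)
    counts = agent.history.sum(axis=1)
    unvisited = choose_zero(counts)
    if unvisited is not None:
        return unvisited
    sample_means = agent.history[:, 1] / counts
    return random_argmax(sample_means)


def ucb_choose_arm(agent):
    counts = agent.history.sum(axis=1)
    unvisited = choose_zero(counts)
    if unvisited is not None:
        return unvisited
    sample_means = agent.history[:, 1] / counts
    bounds = np.sqrt(np.log(2 * agent.count / agent.delta) / (2 * counts))
    return random_argmax(sample_means + bounds)


# Hypothetical stand-in for the chapter's Agent: row k of `history` holds
# [number of 0 rewards, number of 1 rewards] observed for arm k.
demo_agent = SimpleNamespace(
    K=3, count=10, delta=0.1, history=np.array([[4, 1], [1, 3], [0, 1]])
)
print(ucb_choose_arm(demo_agent))  # prints 2: the least-pulled arm has the widest bound
```

Keeping the functions agent-first means the chapter's `choose_arm` methods stay one-liners, while any state such as `greedy_arm` still lives on the agent instance.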