Skip to content

Commit

Permalink
Feat/Add Review Logs to Simulator Output and Optimize Brier Plot Disp…
Browse files Browse the repository at this point in the history
…lay (#148)

* Feat/return review logs from simulator

* Feat/return review logs from simulator

* bump version
  • Loading branch information
L-M-Sherlock authored Dec 6, 2024
1 parent ae7575c commit 0b880f9
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 13 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "FSRS-Optimizer"
version = "5.3.0"
version = "5.4.0"
readme = "README.md"
dependencies = [
"matplotlib>=3.7.0",
Expand Down
15 changes: 5 additions & 10 deletions src/fsrs_optimizer/fsrs_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1451,6 +1451,7 @@ def find_optimal_retention(
learn_cnt_per_day,
memorized_cnt_per_day,
cost_per_day,
_,
) = simulate(**simulate_config)

def moving_average(data, window_size=365 // 20):
Expand Down Expand Up @@ -1873,12 +1874,6 @@ def plot_brier(predictions, real, bins=20, ax=None, title=None):
bin_prediction_means[mask],
sample_weight=bin_counts[mask],
)
tqdm.write(f"R-squared: {r2:.4f}")
tqdm.write(f"MAE: {mae:.4f}")
tqdm.write(f"ICI: {ici:.4f}")
tqdm.write(f"E50: {e_50:.4f}")
tqdm.write(f"E90: {e_90:.4f}")
tqdm.write(f"EMax: {e_max:.4f}")
ax.set_xlim([0, 1])
ax.set_ylim([0, 1])
ax.grid(True)
Expand All @@ -1888,12 +1883,12 @@ def plot_brier(predictions, real, bins=20, ax=None, title=None):
sm.add_constant(bin_prediction_means[mask]),
weights=bin_counts[mask],
).fit()
tqdm.write(str(fit_wls.params))
y_regression = [fit_wls.params[0] + fit_wls.params[1] * x for x in [0, 1]]
params = fit_wls.params
y_regression = [params[0] + params[1] * x for x in [0, 1]]
ax.plot(
[0, 1],
y_regression,
label="Weighted Least Squares Regression",
label=f"y = {params[0]:.3f} + {params[1]:.3f}x",
color="green",
)
except:
Expand Down Expand Up @@ -1930,7 +1925,7 @@ def plot_brier(predictions, real, bins=20, ax=None, title=None):
ax2.legend(loc="lower center")
if title:
ax.set_title(title)
metrics = {"R-squared": r2, "MAE": mae, "ICI": ici}
metrics = {"R-squared": r2, "MAE": mae, "ICI": ici, "E50": e_50, "E90": e_90, "EMax": e_max}
return metrics


Expand Down
11 changes: 9 additions & 2 deletions src/fsrs_optimizer/fsrs_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def simulate(
)
card_table[col["rating"]] = card_table[col["rating"]].astype(int)

revlogs = {}
review_cnt_per_day = np.zeros(learn_span)
learn_cnt_per_day = np.zeros(learn_span)
memorized_cnt_per_day = np.zeros(learn_span)
Expand Down Expand Up @@ -231,6 +232,11 @@ def mean_reversion(init, current):
today + card_table[col["ivl"]][true_review | true_learn]
)

revlogs[today] = {
"card_id": np.where(true_review | true_learn)[0],
"rating": card_table[col["rating"]][true_review | true_learn],
}

review_cnt_per_day[today] = np.sum(true_review)
learn_cnt_per_day[today] = np.sum(true_learn)
memorized_cnt_per_day[today] = card_table[col["retrievability"]].sum()
Expand All @@ -241,6 +247,7 @@ def mean_reversion(init, current):
learn_cnt_per_day,
memorized_cnt_per_day,
cost_per_day,
revlogs,
)


Expand Down Expand Up @@ -284,7 +291,7 @@ def best_sample_size(days_to_simulate):
SAMPLE_SIZE = best_sample_size(learn_span)

for i in range(SAMPLE_SIZE):
_, _, _, memorized_cnt_per_day, cost_per_day = simulate(
_, _, _, memorized_cnt_per_day, cost_per_day, _ = simulate(
w,
r,
deck_size,
Expand Down Expand Up @@ -629,7 +636,7 @@ def workload_graph(default_params, sampling_size=30):
"review_limit_perday": math.inf,
"max_ivl": 36500,
}
(_, review_cnt_per_day, learn_cnt_per_day, memorized_cnt_per_day, _) = simulate(
(_, review_cnt_per_day, learn_cnt_per_day, memorized_cnt_per_day, _, _) = simulate(
w=default_params["w"],
max_cost_perday=math.inf,
learn_limit_perday=10,
Expand Down
1 change: 1 addition & 0 deletions tests/simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ def test_simulate(self):
learn_cnt_per_day,
memorized_cnt_per_day,
cost_per_day,
revlogs,
) = simulate(w=DEFAULT_PARAMETER, request_retention=0.9)
assert memorized_cnt_per_day[-1] == 5875.025236206539

Expand Down

0 comments on commit 0b880f9

Please sign in to comment.