Skip to content

Commit

Permalink
Feat/speed up find optimal retention (#307)
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock authored Jun 18, 2023
1 parent 8fcf038 commit 9550354
Show file tree
Hide file tree
Showing 3 changed files with 308 additions and 307 deletions.
576 changes: 288 additions & 288 deletions fsrs4anki_optimizer.ipynb

Large diffs are not rendered by default.

37 changes: 19 additions & 18 deletions package/fsrs4anki_optimizer/fsrs4anki_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ def find_optimal_retention(self):
d_offset = 1
r_time = 8
f_time = 25
max_time = 200000
max_time = 1e10

type_block = dict()
type_count = dict()
Expand All @@ -590,7 +590,7 @@ def find_optimal_retention(self):
print(f"average time for recalled cards: {r_time}s")

def stability2index(stability):
return int(round(np.log(stability) / np.log(base)) + index_offset)
return (np.log(stability) / np.log(base)).round().astype(int) + index_offset

def init_stability(d):
return max(((d - self.w[2]) / self.w[3] + 2) * self.w[1] + self.w[0], np.power(base, -index_offset))
Expand All @@ -614,23 +614,24 @@ def cal_next_recall_stability(s, r, d, response):
s0 = init_stability(d)
s0_index = stability2index(s0)
diff = max_time
while diff > 0.1:
while diff > 1:
total_time = time_list[d - 1].sum()
s_indices = np.arange(index_len - 2, -1, -1)
stabilities = stability_list[s_indices]
intervals = np.maximum(1, np.round(stabilities * np.log(recall) / np.log(0.9)))
p_recalls = np.power(0.9, intervals / stabilities)
recall_s = cal_next_recall_stability(stabilities, p_recalls, d, 1)
forget_d = np.minimum(d + d_offset, 10)
forget_s = cal_next_recall_stability(stabilities, p_recalls, forget_d, 0)
recall_s_indices = np.minimum(stability2index(recall_s), index_len - 1)
forget_s_indices = np.clip(stability2index(forget_s), 0, index_len - 1)
recall_times = time_list[d - 1][recall_s_indices] + r_time
forget_times = time_list[forget_d - 1][forget_s_indices] + f_time
exp_times = p_recalls * recall_times + (1.0 - p_recalls) * forget_times
mask = exp_times < time_list[d - 1][s_indices]
time_list[d - 1][s_indices[mask]] = exp_times[mask]
diff = total_time - time_list[d - 1].sum()
s0_time = time_list[d - 1][s0_index]
for s_index in range(index_len - 2, -1, -1):
stability = stability_list[s_index];
interval = max(1, round(stability * np.log(recall) / np.log(0.9)))
p_recall = np.power(0.9, interval / stability)
recall_s = cal_next_recall_stability(stability, p_recall, d, 1)
forget_d = min(d + d_offset, 10)
forget_s = cal_next_recall_stability(stability, p_recall, forget_d, 0)
recall_s_index = min(stability2index(recall_s), index_len - 1)
forget_s_index = min(max(stability2index(forget_s), 0), index_len - 1)
recall_time = time_list[d - 1][recall_s_index] + r_time
forget_time = time_list[forget_d - 1][forget_s_index] + f_time
exp_time = p_recall * recall_time + (1.0 - p_recall) * forget_time
if exp_time < time_list[d - 1][s_index]:
time_list[d - 1][s_index] = exp_time
diff = s0_time - time_list[d - 1][s0_index]
df.loc[0 if pd.isnull(df.index.max()) else df.index.max() + 1] = [recall, d, s0_time]

df.sort_values(by=["difficulty", "retention"], inplace=True)
Expand Down
2 changes: 1 addition & 1 deletion package/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "fsrs4anki_optimizer"
version = "3.24.6"
version = "3.25.0"
readme = "README.md"
dependencies = [
"matplotlib>=3.7.0",
Expand Down

0 comments on commit 9550354

Please sign in to comment.