Skip to content

Commit

Permalink
Fix/calibration skip empty ratings (#369)
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock authored Jul 20, 2023
1 parent 2b64b93 commit 558bbf8
Show file tree
Hide file tree
Showing 3 changed files with 352 additions and 330 deletions.
644 changes: 321 additions & 323 deletions fsrs4anki_optimizer.ipynb

Large diffs are not rendered by default.

36 changes: 30 additions & 6 deletions package/fsrs4anki_optimizer/fsrs4anki_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from torch.utils.data import Dataset, DataLoader, Sampler
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from scipy.optimize import curve_fit
from itertools import accumulate
from tqdm.auto import tqdm
Expand Down Expand Up @@ -399,6 +399,25 @@ def cum_concat(x):
df['r_history']=[','.join(map(str, item[:-1])) for sublist in r_history for item in sublist]
df = df.groupby('cid').filter(lambda group: group['id'].min() > time.mktime(datetime.strptime(revlog_start_date, "%Y-%m-%d").timetuple()) * 1000)
df['y'] = df['r'].map(lambda x: {1: 0, 2: 1, 3: 1, 4: 1}[x])

# def remove_outliers(group: pd.DataFrame) -> pd.DataFrame:
# threshold = np.mean(group['delta_t']) * 1.5
# # threshold = group['delta_t'].quantile(0.95)
# group = group[group['delta_t'] < threshold]
# return group

# df = df.groupby(by=['r_history', 't_history'], as_index=False, group_keys=False).apply(remove_outliers)

# def remove_non_continuous_rows(group):
# discontinuity = group['i'].diff().fillna(1).ne(1)
# if not discontinuity.any():
# return group
# else:
# first_non_continuous_index = discontinuity.idxmax()
# return group.loc[:first_non_continuous_index-1]

# df = df.groupby('cid', as_index=False, group_keys=False).progress_apply(remove_non_continuous_rows)

df.to_csv('revlog_history.tsv', sep="\t", index=False)
tqdm.write("Trainset saved.")

Expand Down Expand Up @@ -494,7 +513,7 @@ def pretrain(self, verbose=True):
if total_count < 100:
tqdm.write(f'Not enough data for first rating {first_rating}. Expected at least 100, got {total_count}.')
continue
params, _ = curve_fit(power_forgetting_curve, delta_t, recall, sigma=1/count, bounds=((0.1), (60 if total_count < 1000 else 365)))
params, _ = curve_fit(power_forgetting_curve, delta_t, recall, sigma=1/np.sqrt(count), bounds=((0.1), (60 if total_count < 1000 else 365)))
stability = params[0]
rating_stability[int(first_rating)] = stability
rating_count[int(first_rating)] = total_count
Expand Down Expand Up @@ -527,7 +546,7 @@ def pretrain(self, verbose=True):
def S0_rating_curve(rating, a, b, c):
return np.exp(a + b * rating) + c

params, covs = curve_fit(S0_rating_curve, list(rating_stability.keys()), list(rating_stability.values()), sigma=1/np.array(list(rating_count.values())), method='dogbox', bounds=((-15, 0.03, -5), (15, 7, 30)))
params, covs = curve_fit(S0_rating_curve, list(rating_stability.keys()), list(rating_stability.values()), sigma=1/np.sqrt(list(rating_count.values())), method='dogbox', bounds=((-15, 0.03, -5), (15, 7, 30)))
if verbose:
tqdm.write(f'Weighted fit parameters: {params}')
predict_stability = S0_rating_curve(np.array(list(rating_stability.keys())), *params)
Expand Down Expand Up @@ -798,8 +817,11 @@ def calibration_graph(self):
plot_brier(self.dataset['p'], self.dataset['y'], bins=40, ax=fig1.add_subplot(111))
fig2 = plt.figure(figsize=(16, 12))
for last_rating in ("1","2","3","4"):
calibration_data = self.dataset[self.dataset['r_history'].str.endswith(last_rating)]
if calibration_data.empty:
continue
tqdm.write(f"\nLast rating: {last_rating}")
plot_brier(self.dataset[self.dataset['r_history'].str.endswith(last_rating)]['p'], self.dataset[self.dataset['r_history'].str.endswith(last_rating)]['y'], bins=40, ax=fig2.add_subplot(2, 2, int(last_rating)), title=f"Last rating: {last_rating}")
plot_brier(calibration_data['p'], calibration_data['y'], bins=40, ax=fig2.add_subplot(2, 2, int(last_rating)), title=f"Last rating: {last_rating}")

def to_percent(temp, position):
return '%1.0f' % (100 * temp) + '%'
Expand Down Expand Up @@ -933,7 +955,7 @@ def load_brier(predictions, real, bins=20):
prediction_means = prediction / counts
prediction_means[np.isnan(prediction_means)] = ((np.arange(bins) + 0.5) / bins)[np.isnan(prediction_means)]
correct_means = correct / counts
correct_means[np.isnan(correct_means)] = 0
correct_means[np.isnan(correct_means)] = ((np.arange(bins) + 0.5) / bins)[np.isnan(correct_means)]
size = len(predictions)
answer_mean = sum(correct) / size
return {
Expand All @@ -954,9 +976,11 @@ def plot_brier(predictions, real, bins=20, ax=None, title=None):
bin_correct_means = brier['detail']['bin_correct_means']
bin_counts = brier['detail']['bin_counts']
r2 = r2_score(bin_correct_means, bin_prediction_means, sample_weight=bin_counts)
rmse = np.sqrt(mean_squared_error(bin_correct_means, bin_prediction_means, sample_weight=bin_counts))
rmse = mean_squared_error(bin_correct_means, bin_prediction_means, sample_weight=bin_counts, squared=False)
mae = mean_absolute_error(bin_correct_means, bin_prediction_means, sample_weight=bin_counts)
tqdm.write(f"R-squared: {r2:.4f}")
tqdm.write(f"RMSE: {rmse:.4f}")
tqdm.write(f"MAE: {mae:.4f}")
ax.set_xlim([0, 1])
ax.set_ylim([0, 1])
ax.grid(True)
Expand Down
2 changes: 1 addition & 1 deletion package/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "fsrs4anki_optimizer"
version = "4.1.2"
version = "4.1.3"
readme = "README.md"
dependencies = [
"matplotlib>=3.7.0",
Expand Down

0 comments on commit 558bbf8

Please sign in to comment.