Feat/option enable_short_term in training (#151)
* Feat/option enable_short_term in training

* bump version
L-M-Sherlock authored Dec 17, 2024
1 parent 6882c8c commit 58f98a8
Showing 2 changed files with 17 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "FSRS-Optimizer"
-version = "5.4.2"
+version = "5.5.0"
 readme = "README.md"
 dependencies = [
     "matplotlib>=3.7.0",
18 changes: 16 additions & 2 deletions src/fsrs_optimizer/fsrs_optimizer.py
@@ -299,7 +299,11 @@ def __init__(
         batch_size: int = 256,
         max_seq_len: int = 64,
         float_delta_t: bool = False,
+        enable_short_term: bool = True,
     ) -> None:
+        if not enable_short_term:
+            init_w[17] = 0
+            init_w[18] = 0
         self.model = FSRS(init_w, float_delta_t)
         self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
         self.clipper = ParameterClipper()
@@ -315,6 +319,7 @@ def __init__
         self.avg_eval_losses = []
         self.loss_fn = nn.BCELoss(reduction="none")
         self.float_delta_t = float_delta_t
+        self.enable_short_term = enable_short_term
 
     def build_dataset(self, train_set: pd.DataFrame, test_set: Optional[pd.DataFrame]):
         self.train_set = BatchDataset(
@@ -353,9 +358,12 @@ def train(self, verbose: bool = True):
                 retentions = power_forgetting_curve(delta_ts, stabilities)
                 loss = self.loss_fn(retentions, labels).sum()
                 loss.backward()
-                if self.float_delta_t:
+                if self.float_delta_t or not self.enable_short_term:
                     for param in self.model.parameters():
                         param.grad[:4] = torch.zeros(4)
+                if not self.enable_short_term:
+                    for param in self.model.parameters():
+                        param.grad[17:19] = torch.zeros(2)
                 self.optimizer.step()
                 self.scheduler.step()
                 self.model.apply(self.clipper)
@@ -504,10 +512,14 @@ def loss(stability):
 
 class Optimizer:
     float_delta_t: bool = False
+    enable_short_term: bool = True
 
-    def __init__(self, float_delta_t: bool = False) -> None:
+    def __init__(
+        self, float_delta_t: bool = False, enable_short_term: bool = True
+    ) -> None:
         tqdm.pandas()
         self.float_delta_t = float_delta_t
+        self.enable_short_term = enable_short_term
         global S_MIN
         S_MIN = 1e-6 if float_delta_t else 0.01
 
@@ -1197,6 +1209,7 @@ def train(
                     lr=lr,
                     batch_size=batch_size,
                     float_delta_t=self.float_delta_t,
+                    enable_short_term=self.enable_short_term,
                 )
                 w.append(trainer.train(verbose=verbose))
                 self.w = w[-1]
@@ -1217,6 +1230,7 @@ def train(
                     lr=lr,
                     batch_size=batch_size,
                     float_delta_t=self.float_delta_t,
+                    enable_short_term=self.enable_short_term,
                 )
                 w.append(trainer.train(verbose=verbose))
                 if verbose:
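
Taken together, the hunks above implement enable_short_term as a freeze: when the flag is False, the Trainer zeroes init_w[17] and init_w[18] before building the model, and after every loss.backward() it zeroes the matching gradient entries so the optimizer never moves them. A minimal standalone sketch of that gradient-freezing pattern in plain PyTorch (the 19-element weight vector and the toy loss are hypothetical stand-ins, not the project's FSRS module or training loop):

import torch

# Hypothetical stand-in for the 19-element FSRS weight vector.
w = torch.nn.Parameter(torch.full((19,), 0.5))
optimizer = torch.optim.Adam([w], lr=4e-2)

loss = (w ** 2).sum()            # toy loss, just to produce gradients
loss.backward()
w.grad[17:19] = torch.zeros(2)   # freeze the short-term parameters, as in Trainer.train()
optimizer.step()                 # w[17] and w[18] stay unchanged: a zero grad keeps Adam's state at zero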

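From the user's side, the flag is plumbed through the Optimizer constructor down to every Trainer it creates. A hedged usage sketch (it assumes the package's top-level import re-exports Optimizer, and it omits data loading and the rest of the pipeline):

from fsrs_optimizer import Optimizer  # assumed re-export of fsrs_optimizer.fsrs_optimizer.Optimizer

# The default (enable_short_term=True) keeps the previous behaviour.
# Passing False fixes the short-term parameters w[17] and w[18] at 0 during training.
optimizer = Optimizer(float_delta_t=False, enable_short_term=False)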