diff --git a/pyproject.toml b/pyproject.toml
index ec2f1ac..10ad86c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "FSRS-Optimizer"
-version = "5.5.0"
+version = "5.6.0"
 readme = "README.md"
 dependencies = [
     "matplotlib>=3.7.0",
diff --git a/src/fsrs_optimizer/fsrs_optimizer.py b/src/fsrs_optimizer/fsrs_optimizer.py
index b18b9af..72491a7 100644
--- a/src/fsrs_optimizer/fsrs_optimizer.py
+++ b/src/fsrs_optimizer/fsrs_optimizer.py
@@ -239,6 +239,10 @@ def __init__(
         self.seq_len = torch.tensor(
             dataframe["tensor"].map(len).values, dtype=torch.long
         )
+        if "weights" in dataframe.columns:
+            self.weights = torch.tensor(dataframe["weights"].values, dtype=torch.float)
+        else:
+            self.weights = torch.ones(len(dataframe), dtype=torch.float)
         length = len(dataframe)
         batch_num, remainder = divmod(length, max(1, batch_size))
         self.batch_num = batch_num + 1 if remainder > 0 else batch_num
@@ -256,6 +260,7 @@ def __init__(
                 self.t_train[start_index:end_index].to(device),
                 self.y_train[start_index:end_index].to(device),
                 seq_lens.to(device),
+                self.weights[start_index:end_index].to(device),
             )
 
     def __getitem__(self, idx):
@@ -351,12 +356,12 @@ def train(self, verbose: bool = True):
             for i, batch in enumerate(self.train_data_loader):
                 self.model.train()
                 self.optimizer.zero_grad()
-                sequences, delta_ts, labels, seq_lens = batch
+                sequences, delta_ts, labels, seq_lens, weights = batch
                 real_batch_size = seq_lens.shape[0]
                 outputs, _ = self.model(sequences)
                 stabilities = outputs[seq_lens - 1, torch.arange(real_batch_size), 0]
                 retentions = power_forgetting_curve(delta_ts, stabilities)
-                loss = self.loss_fn(retentions, labels).sum()
+                loss = (self.loss_fn(retentions, labels) * weights).sum()
                 loss.backward()
                 if self.float_delta_t or not self.enable_short_term:
                     for param in self.model.parameters():
@@ -395,17 +400,18 @@ def eval(self):
             if len(dataset) == 0:
                 losses.append(0)
                 continue
-            sequences, delta_ts, labels, seq_lens = (
+            sequences, delta_ts, labels, seq_lens, weights = (
                 dataset.x_train,
                 dataset.t_train,
                 dataset.y_train,
                 dataset.seq_len,
+                dataset.weights,
             )
             real_batch_size = seq_lens.shape[0]
             outputs, _ = self.model(sequences.transpose(0, 1))
             stabilities = outputs[seq_lens - 1, torch.arange(real_batch_size), 0]
             retentions = power_forgetting_curve(delta_ts, stabilities)
-            loss = self.loss_fn(retentions, labels).mean()
+            loss = (self.loss_fn(retentions, labels) * weights).mean()
             losses.append(loss)
         self.avg_train_losses.append(losses[0])
         self.avg_eval_losses.append(losses[1])
@@ -1181,6 +1187,7 @@ def train(
         batch_size: int = 512,
         verbose: bool = True,
         split_by_time: bool = False,
+        recency_weight: bool = False,
     ):
         """Step 4"""
         self.dataset["tensor"] = self.dataset.progress_apply(
@@ -1193,9 +1200,9 @@ def train(
 
         w = []
         plots = []
+        self.dataset.sort_values(by=["review_time"], inplace=True)
         if split_by_time:
             tscv = TimeSeriesSplit(n_splits=5)
-            self.dataset.sort_values(by=["review_time"], inplace=True)
             for i, (train_index, test_index) in enumerate(tscv.split(self.dataset)):
                 if verbose:
                     tqdm.write(f"TRAIN: {len(train_index)} TEST: {len(test_index)}")
@@ -1222,6 +1229,8 @@ def train(
                 print(metrics)
                 plots.append(trainer.plot())
         else:
+            if recency_weight:
+                self.dataset["weights"] = np.linspace(0.5, 1.5, len(self.dataset))
             trainer = Trainer(
                 self.dataset,
                 None,
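
Note (not part of the diff): the change above threads an optional per-review weight through the dataset batches and multiplies it into the loss before reduction. With `recency_weight=True`, the dataset is sorted by `review_time` and weights are linearly spaced from 0.5 (oldest review) to 1.5 (newest). A minimal sketch of that weighting, using toy loss values (the numbers here are illustrative, not from the diff):

```python
import torch

# Toy per-review losses for five reviews already sorted by review_time,
# oldest first (illustrative values, not from the diff).
losses = torch.tensor([0.40, 0.35, 0.30, 0.25, 0.20])

# Recency weights as introduced by the diff: linearly spaced from 0.5
# (oldest review) to 1.5 (newest), so recent reviews count more.
weights = torch.linspace(0.5, 1.5, steps=len(losses))

# Matches the new training reduction: a weighted sum instead of a plain sum.
weighted_loss = (losses * weights).sum()
print(weighted_loss)  # tensor(1.3750)
```

Since the weights average to 1.0, the overall loss scale is roughly unchanged; only the relative emphasis shifts toward recent reviews.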