Commit 5ff362c (1 parent: b28ce1a). Showing 5 changed files with 249 additions and 25 deletions.
@@ -0,0 +1 @@
from .train import ScCoralTrainingPlan as _training_plan
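For context, a minimal sketch of how this re-export might be consumed by an scvi-tools style model class. The `ScCoral` class, the `sccoral.train` import path, and the `_training_plan_cls` hook are illustrative assumptions, not part of this commit:

# Hedged sketch of a downstream consumer; all names below are illustrative assumptions.
from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin

from sccoral.train import _training_plan  # the alias re-exported above


class ScCoral(UnsupervisedTrainingMixin, BaseModelClass):
    # following the scvi-tools convention, train() would instantiate this plan class
    _training_plan_cls = _training_plan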
@@ -0,0 +1,107 @@
from typing import Literal

from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import BaseFinetuning, EarlyStopping
from torch.optim.optimizer import Optimizer


class EarlyStoppingCheck(EarlyStopping):
    """Check whether the early stopping condition is met, but do not interrupt training.

    Modified `lightning.pytorch.callbacks.EarlyStopping` class that, instead of
    sending an early stopping signal to the `Trainer`, sets the attribute
    `pretraining_early_stopping_condition` on the `TrainingPlan` to `True`.

    Parameters
    ----------
    monitor
        Which loss to monitor.
    min_delta
        Minimum change in the monitored loss to count as an improvement.
    patience
        Number of consecutive epochs without improvement before the stopping
        condition is set.
    mode
        Whether the monitored loss should be minimized or maximized.
    check_on_train
        Whether to check at the end of the training epoch or the validation
        epoch. Defaults to the training epoch.
    **kwargs
        Other arguments passed to `lightning.pytorch.callbacks.EarlyStopping`.
    """

    def __init__(
        self,
        monitor="reconstruction_loss",
        min_delta: float = 0.0,
        patience: int = 5,
        mode: Literal["max", "min"] = "min",
        check_on_train: bool = True,
        **kwargs,
    ):
        # Pass arguments by keyword: `EarlyStopping.__init__` takes `verbose` as its
        # fourth positional parameter, so positional passing would hand `mode` to `verbose`.
        super().__init__(monitor=monitor, min_delta=min_delta, patience=patience, mode=mode, **kwargs)

        self.check_on_train = check_on_train

        self.state = {}

    def _run_early_stopping_check(self, trainer: Trainer, pl_module: LightningModule):
        """Overwrite the method that stops the trainer."""
        pass

    def _check_stopping(self, trainer: Trainer, pl_module: LightningModule):
        logs = trainer.callback_metrics
        current = logs[self.monitor].squeeze()
        should_stop, reason = self._evaluate_stopping_criteria(current)

        # Record on the training plan whether pretraining should stop
        pl_module.pretraining_early_stopping_condition = should_stop

    def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
        if not self.check_on_train:
            return
        self._check_stopping(trainer, pl_module)

    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
        if self.check_on_train:
            return
        self._check_stopping(trainer, pl_module)


class PretrainingFreezeWeights(BaseFinetuning):
    """Freeze the weights of part of the module until pretraining ends.

    Parameters
    ----------
    submodule
        Name of the model submodule whose weights are frozen during pretraining.
    n_pretraining_epochs
        Maximal number of pretraining epochs.
    early_stopping
        Whether to use `EarlyStoppingCheck` as an additional stopping criterion.
    **kwargs
        Other keyword arguments passed to `lightning.pytorch.callbacks.BaseFinetuning`.
    """

    def __init__(self, submodule: str = "z_encoder", n_pretraining_epochs: int = 500, early_stopping=True, **kwargs):
        super().__init__(**kwargs)

        self.n_pretraining_epochs = n_pretraining_epochs
        self.early_stopping = early_stopping
        self.submodule = submodule

    def freeze_before_training(self, pl_module: LightningModule) -> None:
        module = getattr(pl_module.module, self.submodule)
        self.freeze(module)

    def finetune_function(self, pl_module: LightningModule, epoch: int, optimizer: Optimizer) -> None:
        if pl_module.is_pretrained:  # skip if pretraining is finished
            return
        early_stopping_condition = False
        if self.early_stopping:
            early_stopping_condition = pl_module.pretraining_early_stopping_condition
        if (epoch == self.n_pretraining_epochs) or early_stopping_condition:
            self.unfreeze_and_add_param_group(
                modules=getattr(pl_module.module, self.submodule), optimizer=optimizer, train_bn=True
            )
            pl_module.is_pretrained = True
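To show how the two callbacks above are meant to work together with the training plan defined below, here is a minimal, hedged wiring sketch. The `module` object stands in for an scvi-tools module exposing a `z_encoder` submodule and is not part of this commit:

# Hedged sketch, not the project's actual train() wiring.
from lightning import Trainer

module = ...  # hypothetical scvi-tools module exposing a `z_encoder` submodule
plan = ScCoralTrainingPlan(module=module)
callbacks = [
    EarlyStoppingCheck(monitor="reconstruction_loss", patience=5),
    PretrainingFreezeWeights(submodule="z_encoder", n_pretraining_epochs=500),
]
trainer = Trainer(max_epochs=1000, callbacks=callbacks)
# trainer.fit(plan, datamodule=...)  # data loading omitted in this sketch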
@@ -0,0 +1,37 @@
from collections.abc import Iterable
from typing import Callable

import optax
import torch
from scvi.train import TrainingPlan

JaxOptimizerCreator = Callable[[], optax.GradientTransformation]
TorchOptimizerCreator = Callable[[Iterable[torch.Tensor]], torch.optim.Optimizer]


class ScCoralTrainingPlan(TrainingPlan):
    """Implement the custom pretraining procedure for sccoral."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self._is_pretrained = False
        self._pretraining_early_stopping_condition = False

    @property
    def is_pretrained(self):
        """Whether the model is pretrained."""
        return self._is_pretrained

    @is_pretrained.setter
    def is_pretrained(self, value: bool):
        self._is_pretrained = value

    @property
    def pretraining_early_stopping_condition(self):
        """Whether pretraining should stop."""
        return self._pretraining_early_stopping_condition

    @pretraining_early_stopping_condition.setter
    def pretraining_early_stopping_condition(self, value: bool):
        self._pretraining_early_stopping_condition = value
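As a quick way to observe what the callbacks toggle on this plan during training, a small hedged helper; it assumes a `ScCoralTrainingPlan` instance whose wrapped module has a `z_encoder` submodule, mirroring the defaults above:

# Hedged inspection helper; `plan` is assumed to be a ScCoralTrainingPlan as above.
def pretraining_status(plan: ScCoralTrainingPlan) -> dict:
    """Summarize the pretraining handshake flags and the encoder freeze state."""
    encoder = plan.module.z_encoder  # submodule frozen by PretrainingFreezeWeights
    return {
        "is_pretrained": plan.is_pretrained,
        "early_stop_condition": plan.pretraining_early_stopping_condition,
        "z_encoder_trainable": any(p.requires_grad for p in encoder.parameters()),
    }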