Commit
Added callbacks
lucas-diedrich committed Jan 25, 2024
1 parent b28ce1a commit 5ff362c
Showing 5 changed files with 249 additions and 25 deletions.
106 changes: 92 additions & 14 deletions sccoral/model/_model.py
@@ -6,19 +6,25 @@
import anndata as ad
import pandas as pd
from scvi import REGISTRY_KEYS
from scvi.autotune import Tunable, TunableMixin
from scvi.data import AnnDataManager
from scvi.data.fields import CategoricalJointObsField, CategoricalObsField, LayerField, NumericalJointObsField
from scvi.dataloaders import DataSplitter
from scvi.model._utils import _init_library_size
from scvi.model.base import BaseModelClass

# from scvi.train import TrainingPlan
from scvi.train import TrainRunner
from scvi.utils import setup_anndata_dsp
from torch import inference_mode

from sccoral.module import MODULE
from sccoral.train import _training_plan

logger = logging.getLogger(__name__)


class SCCORAL(BaseModelClass):
class SCCORAL(TunableMixin, BaseModelClass):
"""Single-cell COvariate-informed Regularized variational Autoencoder with Linear Decoder
Parameters
@@ -72,17 +78,20 @@ class SCCORAL(BaseModelClass):
"""

_module_cls = MODULE
_data_splitter_cls = DataSplitter
_training_plan_cls = _training_plan
_train_runner_cls = TrainRunner

def __init__(
self,
adata: ad.AnnData,
n_latent: int = 10,
alpha_l1: float = 0.1,
n_hidden: int = 128,
n_layers: int = 1,
dropout_rate: float = 0.1,
alpha_l1: Tunable[float] = 0.1,
n_hidden: Tunable[int] = 128,
n_layers: Tunable[int] = 1,
dropout_rate: Tunable[float] = 0.1,
# TODO dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene",
gene_likelihood: Literal["nb", "zinb", "poisson"] = "nb",
gene_likelihood: Tunable[Literal["nb", "zinb", "poisson"]] = "nb",
# TODO latent_distribution: Literal["normal", "lognormal"] = "normal",
**model_kwargs,
) -> None:
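The `Tunable[...]` annotations above register these hyperparameters with scvi-tools' autotune machinery (enabled by the new `TunableMixin` base class). A minimal tuning sketch follows; `ModelTuner` and the exact `fit` keywords are assumptions based on the scvi-tools releases that ship `Tunable`/`TunableMixin`, and `adata` is a placeholder for a registered AnnData object, not part of this commit:

# Hypothetical autotune sketch -- not part of this commit.
from ray import tune
from scvi.autotune import ModelTuner

from sccoral.model import SCCORAL

SCCORAL.setup_anndata(adata)  # `adata`: a preprocessed AnnData object (assumed)

tuner = ModelTuner(SCCORAL)
tuner.info()  # lists the hyperparameters registered as Tunable
results = tuner.fit(
    adata,
    metric="validation_loss",  # metric name is an assumption
    search_space={
        "n_latent": tune.choice([10, 20, 30]),
        "alpha_l1": tune.loguniform(1e-3, 1.0),
    },
    num_samples=20,
)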
@@ -210,14 +219,83 @@ def setup_anndata(
def train(
self,
max_epochs: int = 500,
max_pretraining_epochs: None | int = 500,
accelerator: None | Literal["cpu", "gpu"] = None,
validation_size: None | float = None,
plan_kwargs: None | dict[str, Any] = None,
max_pretraining_epochs: None | int = None,
use_gpu: None | bool = None,
accelerator: None | Literal["cpu", "gpu", "auto"] = "auto",
devices="auto",
train_size: None | float = 0.9,
batch_size: int = 128,
early_stopping_pretraining: bool = True,
early_stopping: bool = True,
**trainer_kwargs: Any,
pretraining: bool = True,
pretraining_kwargs: None | dict[str, Any] = None,
plan_kwargs: None | dict[str, Any] = None,
trainer_kwargs: None | dict[str, Any] = None,
) -> None:
# TODO
pass
"""Train sccoral model
Training is split into pretraining (only training on covariates, frozen z_encoder weights)
and training (unfrozen weights)
max_epochs
Maximum epochs during training
max_pretraining_epochs
Maximum epochs during pretraining. If `None`, same as max_epochs
use_gpu
Whether to use gpu. If `None` automatically detects gpu
accelerator
cpu/gpu/auto: auto automatically detects available devices
devices
If `auto`, automatically detects available devices
train_size
Size of train split (0-1). Rest is validation split
batch_size
Size of minibatches during training
early_stopping
Enable early stopping during training
early_stopping_pretraining
Enable early stopping during pretraining
pretraining
Whether to conduct pretraining
pretraining_kwargs
Additional keyword arguments passed to `sccoral.train.TrainingPlan`
affecting pretraining
plan_kwargs
Training keyword arguments passed to `sccoral.train.TrainingPlan`
trainer_kwargs
Additional keyword arguments passed to `scvi.train.TrainRunner`
"""
# TODO Refactor
if not pretraining:
pretraining_kwargs = {}
elif pretraining_kwargs is None:
pretraining_kwargs = {"early_stopping": early_stopping_pretraining, "early_stopping_patience": 5}
if plan_kwargs is None:
plan_kwargs = {}

# PRETRAINING and TRAINING are handled by the same plan;
# keyword arguments are PASSED TO the TrainingPlan (and eventually pl.Trainer)
training_plan = self._training_plan_cls(module=self.module, **pretraining_kwargs, **plan_kwargs)

assert 0 < train_size <= 1, "train_size must be in (0, 1]"
validation_size = 1 - train_size

data_splitter = self._data_splitter_cls(
self.adata_manager,
train_size=train_size,
validation_size=validation_size,
batch_size=batch_size,
use_gpu=use_gpu,
)

# Should be left as is
runner = self._train_runner_cls(
self,
training_plan=training_plan,
data_splitter=data_splitter,
max_epochs=max_epochs,
use_gpu=use_gpu,
accelerator=accelerator,
devices=devices,
**(trainer_kwargs or {}),
)

return runner()
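A minimal usage sketch of the two-stage training implemented above (pretraining with a frozen z_encoder, then full training). The file name and the `setup_anndata` registration arguments are placeholders, not part of this diff:

import anndata as ad

from sccoral.model import SCCORAL

adata = ad.read_h5ad("my_dataset.h5ad")  # hypothetical input file
SCCORAL.setup_anndata(adata)  # covariate registration arguments omitted (assumed)

model = SCCORAL(adata, n_latent=10, alpha_l1=0.1)
model.train(
    max_epochs=500,
    pretraining=True,                 # stage 1: covariate encoders only, z_encoder frozen
    early_stopping_pretraining=True,
    early_stopping=True,              # stage 2: all weights trained
    train_size=0.9,
    batch_size=128,
)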
23 changes: 12 additions & 11 deletions sccoral/module/_module.py
@@ -6,6 +6,7 @@
import torch
import torch.functional as F
from scvi import REGISTRY_KEYS
from scvi.autotune import Tunable
from scvi.distributions import NegativeBinomial, Poisson, ZeroInflatedNegativeBinomial
from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data
from scvi.nn import Encoder, one_hot
@@ -64,22 +65,22 @@ def __init__(
n_input: int,
categorical_mapping: None | dict[str, int],
continuous_names: None | Iterable,
alpha_l1: float = 0,
alpha_l1: Tunable[float] = 0,
n_batch: int = 0,
n_labels: int = 0, # TODO gene-labels not implemented
n_hidden: int = 128,
n_hidden: Tunable[int] = 128,
n_latent: int = 10,
n_layers: int = 1,
distribution: Literal["normal", "ln"] = "normal",
dropout_rate: int = 0.1,
n_layers: Tunable[int] = 1,
distribution: Tunable[Literal["normal", "ln"]] = "normal",
dropout_rate: Tunable[float] = 0.1,
log_variational: bool = True, # as LSCVI
gene_likelihood: Literal["nb", "zinb", "poisson"] = "nb", # as LSCVI
latent_distribution: Literal["normal", "ln"] = "normal", # as LSCVI
dispersion: Literal["gene", "gene-batch", "gene-cell"] = "gene", # TODO gene-labels not implemented
use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "encoder",
use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "none",
gene_likelihood: Tunable[Literal["nb", "zinb", "poisson"]] = "nb", # as LSCVI
latent_distribution: Tunable[Literal["normal", "ln"]] = "normal", # as LSCVI
dispersion: Tunable[Literal["gene", "gene-batch", "gene-cell"]] = "gene", # TODO gene-labels not implemented
use_batch_norm: Tunable[Literal["encoder", "decoder", "none", "both"]] = "encoder",
use_layer_norm: Tunable[Literal["encoder", "none"]] = "none",
# use_size_factor_key: bool = False, #TODO SKIP
use_observed_lib_size: bool = True, # TODO LSCVI overwrites this flag and uses False
use_observed_lib_size: Tunable[bool] = True, # TODO LSCVI overwrites this flag and uses False
library_log_means: None | np.ndarray = None,
library_log_vars: None | np.ndarray = None,
**vae_kwargs,
1 change: 1 addition & 0 deletions sccoral/train/__init__.py
@@ -0,0 +1 @@
from .train import ScCoralTrainingPlan as _training_plan
107 changes: 107 additions & 0 deletions sccoral/train/_callbacks.py
@@ -0,0 +1,107 @@
from typing import Literal

from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import BaseFinetuning, EarlyStopping
from torch.optim.optimizer import Optimizer


class EarlyStoppingCheck(EarlyStopping):
"""Check if early stopping condition is met but do not interrupt training
Modified `lightning.pytorch.callbacks.EarlyStopping` class that
instead of sending early stopping signal to `Trainer` sets the
parameter `pretraining_early_stopping_condition` in `TrainingPlan`
to true
Parameters
----------
monitor
Which loss to monitor.
min_delta
Minimum change in the monitored loss that counts as an improvement
patience
Number of consecutive epochs without improvement before the stopping condition is set
mode
Whether the monitored loss should be minimized or maximized
check_on_train
Whether to check at the end of each training epoch or of each validation epoch. Defaults
to the training epoch
**kwargs
Other arguments passed to `lightning.pytorch.callbacks.EarlyStopping`
"""

def __init__(
self,
monitor="reconstruction_loss",
min_delta: float = 0.0,
patience: int = 5,
mode: Literal["max", "min"] = "min",
check_on_train: bool = True,
**kwargs,
):
super().__init__(monitor=monitor, min_delta=min_delta, patience=patience, mode=mode, **kwargs)  # pass by keyword: EarlyStopping's fourth positional parameter is verbose, not mode

self.check_on_train = check_on_train

self.state = {}

def _run_early_stopping_check(self, trainer: Trainer, pl_module: LightningModule):
"""Overwrite method that stops trainer"""
pass

def _check_stopping(self, trainer: Trainer, pl_module: LightningModule):
logs = trainer.callback_metrics
current = logs[self.monitor].squeeze()
should_stop, reason = self._evaluate_stopping_criteria(current)

# Write if model should stop
pl_module.pretraining_early_stopping_condition = should_stop

def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
if not self.check_on_train:
return
self._check_stopping(trainer, pl_module)

def on_validation_end(self, trainer: Trainer, pl_module: LightningModule):
if self.check_on_train:
return
self._check_stopping(trainer, pl_module)


class PretrainingFreezeWeights(BaseFinetuning):
"""Freeze weights of parts of the module until pretraining ends
Parameters
----------
submodule
Name of the submodule whose weights are frozen during pretraining
n_pretraining_epochs
Maximum number of pretraining epochs
early_stopping
Whether to use `EarlyStoppingCheck` as an additional criterion to end pretraining
**kwargs
Other keyword arguments passed to `lightning.pytorch.callbacks.BaseFinetuning`
"""

def __init__(self, submodule: str = "z_encoder", n_pretraining_epochs: int = 500, early_stopping=True, **kwargs):
super().__init__(**kwargs)

self.n_pretraining_epochs = n_pretraining_epochs
self.early_stopping = early_stopping
self.submodule = submodule

def freeze_before_training(self, pl_module: LightningModule) -> None:
module = getattr(pl_module.module, self.submodule)
self.freeze(module)

def finetune_function(self, pl_module: LightningModule, epoch: int, optimizer: Optimizer) -> None:
if pl_module.is_pretrained: # skip if pretraining is finished
return
early_stopping_condition = False
if self.early_stopping:
early_stopping_condition = pl_module.pretraining_early_stopping_condition
if (epoch == self.n_pretraining_epochs) or early_stopping_condition:
self.unfreeze_and_add_param_group(
modules=getattr(pl_module.module, self.submodule), optimizer=optimizer, train_bn=True
)
pl_module.is_pretrained = True
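For illustration only: how the two callbacks are meant to cooperate when attached to a Lightning `Trainer`. In sccoral the `Trainer` is constructed by scvi's `TrainRunner`, so wiring it by hand as below is a sketch, not the intended API:

from lightning import Trainer

from sccoral.train._callbacks import EarlyStoppingCheck, PretrainingFreezeWeights

callbacks = [
    # Sets `pretraining_early_stopping_condition` on the training plan once the
    # monitored loss stops improving, without stopping the Trainer itself.
    EarlyStoppingCheck(monitor="reconstruction_loss", patience=5, check_on_train=True),
    # Keeps z_encoder frozen until that flag is set (or the epoch cap is reached),
    # then unfreezes it and marks the plan as pretrained.
    PretrainingFreezeWeights(submodule="z_encoder", n_pretraining_epochs=500, early_stopping=True),
]

trainer = Trainer(max_epochs=500, callbacks=callbacks)
# trainer.fit(training_plan, datamodule=data_splitter)  # see sccoral.train.ScCoralTrainingPlan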
37 changes: 37 additions & 0 deletions sccoral/train/train.py
@@ -0,0 +1,37 @@
from collections.abc import Iterable
from typing import Callable

import optax
import torch
from scvi.train import TrainingPlan

JaxOptimizerCreator = Callable[[], optax.GradientTransformation]
TorchOptimizerCreator = Callable[[Iterable[torch.Tensor]], torch.optim.Optimizer]


class ScCoralTrainingPlan(TrainingPlan):
"""Implement custom pretraining procedure for sccoral"""

def __init__(self, **kwargs):
super().__init__(**kwargs)

self._is_pretrained = False
self._pretraining_early_stopping_condition = False

@property
def is_pretrained(self):
"""If model is pretrained"""
return self._is_pretrained

@is_pretrained.setter
def is_pretrained(self, value: bool):
self._is_pretrained = value

@property
def pretraining_early_stopping_condition(self):
"""If pretraining should stop"""
return self._pretraining_early_stopping_condition

@pretraining_early_stopping_condition.setter
def pretraining_early_stopping_condition(self, value: bool):
self._pretraining_early_stopping_condition = value
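A short sketch of the flag handshake the plan exposes: `EarlyStoppingCheck` writes `pretraining_early_stopping_condition`, and `PretrainingFreezeWeights` reads it before setting `is_pretrained`. Here `model` is assumed to be an initialized SCCORAL instance, and any remaining keyword arguments would be forwarded to scvi's `TrainingPlan`:

from sccoral.train.train import ScCoralTrainingPlan

plan = ScCoralTrainingPlan(module=model.module)  # `model`: an initialized SCCORAL (assumed)

# A fresh plan has not been pretrained and carries no stopping signal.
assert plan.is_pretrained is False
assert plan.pretraining_early_stopping_condition is False

# EarlyStoppingCheck flips this flag when the monitored loss plateaus ...
plan.pretraining_early_stopping_condition = True
# ... and PretrainingFreezeWeights then unfreezes z_encoder and marks the plan as pretrained.
plan.is_pretrained = True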
