
Make training parameters tunable
lucas-diedrich committed Mar 22, 2024
1 parent 3828452 commit 119dfdd
Showing 1 changed file with 49 additions and 3 deletions.
52 changes: 49 additions & 3 deletions sccoral/train/_trainingplans.py
@@ -1,9 +1,16 @@
 from collections.abc import Iterable
-from typing import Callable
+from typing import Callable, Literal, Optional
 
 import torch
 from scvi.train import TrainingPlan
 
+# Changes after scvi 1.0.4
+try:
+    from scvi.autotune import Tunable
+except ImportError:
+    from scvi._types import Tunable
+
+
 TorchOptimizerCreator = Callable[[Iterable[torch.Tensor]], torch.optim.Optimizer]
 
 
@@ -14,8 +21,47 @@ class ScCoralTrainingPlan(TrainingPlan):
     custom properties `is_pretrained` and `pretraining_early_stopping_condition`
     """
 
-    def __init__(self, module, **kwargs):
-        super().__init__(module, **kwargs)
+    def __init__(
+        self,
+        module,
+        optimizer: Tunable[Literal["Adam", "AdamW", "Custom"]] = "Adam",
+        optimizer_creator: Optional[TorchOptimizerCreator] = None,
+        lr: Tunable[float] = 1e-3,
+        weight_decay: Tunable[float] = 1e-6,
+        eps: float = 0.01,
+        n_steps_kl_warmup: Tunable[int] = None,
+        n_epochs_kl_warmup: Tunable[int] = 400,
+        reduce_lr_on_plateau: bool = False,
+        lr_factor: Tunable[float] = 0.6,
+        lr_patience: Tunable[int] = 30,
+        lr_threshold: Tunable[float] = 0.0,
+        lr_scheduler_metric: Tunable[
+            Literal["elbo_validation", "reconstruction_loss_validation", "kl_local_validation"]
+        ] = "elbo_validation",
+        lr_min: float = 0,
+        max_kl_weight: Tunable[float] = 1.0,
+        min_kl_weight: Tunable[float] = 0.0,
+        **loss_kwargs,
+    ):
+        super().__init__(
+            module,
+            optimizer=optimizer,
+            optimizer_creator=optimizer_creator,
+            lr=lr,
+            weight_decay=weight_decay,
+            eps=eps,
+            n_steps_kl_warmup=n_steps_kl_warmup,
+            n_epochs_kl_warmup=n_epochs_kl_warmup,
+            reduce_lr_on_plateau=reduce_lr_on_plateau,
+            lr_factor=lr_factor,
+            lr_patience=lr_patience,
+            lr_threshold=lr_threshold,
+            lr_scheduler_metric=lr_scheduler_metric,
+            lr_min=lr_min,
+            max_kl_weight=max_kl_weight,
+            min_kl_weight=min_kl_weight,
+            **loss_kwargs,
+        )
 
         self._is_pretrained = False
         self._pretraining_early_stopping_condition = False
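
For reference, a minimal usage sketch of the new signature (an assumption, not part of this commit): it only passes parameters shown in the diff above, `module` stands in for an already initialized sccoral module, and the chosen values are arbitrary examples. Parameters annotated with `Tunable` are the ones flagged for hyperparameter search; the try/except import above keeps the class working across scvi-tools versions that expose `Tunable` from different locations.

# Usage sketch (assumption, not part of this commit): construct the training plan
# with explicitly chosen hyperparameters. `module` is a placeholder for an already
# initialized sccoral module; the values below are arbitrary examples.
from sccoral.train._trainingplans import ScCoralTrainingPlan

plan = ScCoralTrainingPlan(
    module,
    optimizer="AdamW",          # one of "Adam", "AdamW", "Custom"
    lr=5e-4,                    # initial learning rate
    weight_decay=1e-5,
    n_epochs_kl_warmup=200,     # warm up the KL weight over the first 200 epochs
    reduce_lr_on_plateau=True,  # shrink the learning rate when the metric plateaus
    lr_factor=0.5,
    lr_patience=20,
)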
