Skip to content

Commit

Permalink
fix: replace use_gpu with accelerator and devices (#24)
Browse files Browse the repository at this point in the history
BREAKING CHANGE: users should now pass in `accelerator` and `devices` to
`train` instead of `use_gpu`, in line with the changes introduced in
scvi-tools 1.1.
  • Loading branch information
martinkim0 authored Mar 11, 2024
1 parent a8090d6 commit b9a5cfa
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions velovi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from scvi.dataloaders import DataSplitter
from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin, VAEMixin
from scvi.train import TrainingPlan, TrainRunner
from scvi.utils._docstrings import setup_anndata_dsp
from scvi.utils._docstrings import devices_dsp, setup_anndata_dsp

from ._constants import REGISTRY_KEYS
from ._module import VELOVAE
Expand Down Expand Up @@ -118,12 +118,14 @@ def __init__(
)
self.init_params_ = self._get_init_params(locals())

@devices_dsp.dedent
def train(
self,
max_epochs: Optional[int] = 500,
lr: float = 1e-2,
weight_decay: float = 1e-2,
use_gpu: Optional[Union[str, int, bool]] = None,
accelerator: str = "auto",
devices: Union[int, list[int], str] = "auto",
train_size: float = 0.9,
validation_size: Optional[float] = None,
batch_size: int = 256,
Expand All @@ -143,9 +145,8 @@ def train(
Learning rate for optimization
weight_decay
Weight decay for optimization
use_gpu
Use default GPU if available (if None or True), or index of GPU to use (if int),
or name of GPU (if str, e.g., `'cuda:0'`), or use CPU (if False).
%(param_accelerator)s
%(param_devices)s
train_size
Size of training set in the range [0.0, 1.0].
validation_size
Expand Down Expand Up @@ -189,7 +190,8 @@ def train(
training_plan=training_plan,
data_splitter=data_splitter,
max_epochs=max_epochs,
use_gpu=use_gpu,
accelerator=accelerator,
devices=devices,
**trainer_kwargs,
)
return runner()
Expand Down

0 comments on commit b9a5cfa

Please sign in to comment.