Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Nov 11, 2024
1 parent 88c0377 commit c0a8ee5
Showing 4 changed files with 21 additions and 12 deletions.
10 changes: 5 additions & 5 deletions docs/tutorial.ipynb
Original file line number Diff line number Diff line change
@@ -507,7 +507,7 @@
],
"source": [
"sc.pl.umap(\n",
" adata, \n",
" adata,\n",
" color=\"directional_cosine_sim_variance\",\n",
" cmap=\"Greys\",\n",
" vmin=\"p1\",\n",
@@ -537,7 +537,7 @@
" extrapolated_cells_list = []\n",
" for i in track(range(n_samples)):\n",
" with io.StringIO() as buf, redirect_stdout(buf):\n",
" vkey = \"velocities_velovi_{i}\".format(i=i)\n",
" vkey = f\"velocities_velovi_{i}\"\n",
" v = vae.get_velocity(n_samples=1, velo_statistic=\"mean\")\n",
" adata.layers[vkey] = v\n",
" scv.tl.velocity_graph(adata, vkey=vkey, sqrt_transform=False, approx=True)\n",
@@ -1134,10 +1134,10 @@
],
"source": [
"sc.pl.umap(\n",
" adata, \n",
" adata,\n",
" color=\"directional_cosine_sim_variance_extrinisic\",\n",
" vmin=\"p1\", \n",
" vmax=\"p99\", \n",
" vmin=\"p1\",\n",
" vmax=\"p99\",\n",
")"
]
},
19 changes: 12 additions & 7 deletions velovi/_model.py
Original file line number Diff line number Diff line change
@@ -49,6 +49,7 @@ class VELOVI(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass):
Use a linear decoder from latent space to time.
**model_kwargs
Keyword args for :class:`~velovi.VELOVAE`
"""

def __init__(
@@ -108,13 +109,8 @@ def __init__(
**model_kwargs,
)
self._model_summary_string = (
"VELOVI Model with the following params: \nn_hidden: {}, n_latent: {}, n_layers: {}, dropout_rate: "
"{}"
).format(
n_hidden,
n_latent,
n_layers,
dropout_rate,
f"VELOVI Model with the following params: \nn_hidden: {n_hidden}, n_latent: {n_latent}, n_layers: {n_layers}, dropout_rate: "
f"{dropout_rate}"
)
self.init_params_ = self._get_init_params(locals())

@@ -164,6 +160,7 @@ def train(
`train()` will overwrite values present in `plan_kwargs`, when appropriate.
**trainer_kwargs
Other keyword args for :class:`~scvi.train.Trainer`.
"""
user_plan_kwargs = plan_kwargs.copy() if isinstance(plan_kwargs, dict) else {}
plan_kwargs = {"lr": lr, "weight_decay": weight_decay, "optimizer": "AdamW"}
@@ -238,6 +235,7 @@ def get_state_assignment(
-------
If `n_samples` > 1 and `return_mean` is False, then the shape is `(samples, cells, genes)`.
Otherwise, shape is `(cells, genes)`. In this case, return type is :class:`~pandas.DataFrame` unless `return_numpy` is True.
"""
adata = self._validate_anndata(adata)
scdl = self._make_data_loader(
@@ -342,6 +340,7 @@ def get_latent_time(
-------
If `n_samples` > 1 and `return_mean` is False, then the shape is `(samples, cells, genes)`.
Otherwise, shape is `(cells, genes)`. In this case, return type is :class:`~pandas.DataFrame` unless `return_numpy` is True.
"""
adata = self._validate_anndata(adata)
if indices is None:
@@ -484,6 +483,7 @@ def get_velocity(
-------
If `n_samples` > 1 and `return_mean` is False, then the shape is `(samples, cells, genes)`.
Otherwise, shape is `(cells, genes)`. In this case, return type is :class:`~pandas.DataFrame` unless `return_numpy` is True.
"""
adata = self._validate_anndata(adata)
if indices is None:
@@ -658,6 +658,7 @@ def get_expression_fit(
-------
If `n_samples` > 1 and `return_mean` is False, then the shape is `(samples, cells, genes)`.
Otherwise, shape is `(cells, genes)`. In this case, return type is :class:`~pandas.DataFrame` unless `return_numpy` is True.
"""
adata = self._validate_anndata(adata)

@@ -813,6 +814,7 @@ def get_gene_likelihood(
-------
If `n_samples` > 1 and `return_mean` is False, then the shape is `(samples, cells, genes)`.
Otherwise, shape is `(cells, genes)`. In this case, return type is :class:`~pandas.DataFrame` unless `return_numpy` is True.
"""
adata = self._validate_anndata(adata)
scdl = self._make_data_loader(
@@ -919,6 +921,7 @@ def setup_anndata(
Returns
-------
%(returns)s
"""
setup_method_args = cls._get_setup_method_args(**locals())
anndata_fields = [
@@ -969,6 +972,7 @@ def get_permutation_scores(
-------
Tuple of DataFrame and AnnData. DataFrame is genes by cell types with score per cell type.
AnnData is the permutated version of the original AnnData.
"""
adata = self._validate_anndata(adata)
adata_manager = self.get_anndata_manager(adata)
@@ -1092,6 +1096,7 @@ def _directional_statistics_per_cell(
----------
tensor
Shape of samples by genes for a given cell.
"""
n_samples = tensor.shape[0]
# over samples axis
3 changes: 3 additions & 0 deletions velovi/_module.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Main module."""

from typing import Callable, Iterable, Literal, Optional

import numpy as np
@@ -44,6 +45,7 @@ class DecoderVELOVI(nn.Module):
Whether to use layer norm in layers
linear_decoder
Whether to use linear decoder for time
"""

def __init__(
@@ -183,6 +185,7 @@ class VELOVAE(BaseModuleClass):
var_activation
Callable used to ensure positivity of the variational distributions' variance.
When `None`, defaults to `torch.exp`.
"""

def __init__(
1 change: 1 addition & 0 deletions velovi/_utils.py
Original file line number Diff line number Diff line change
@@ -56,6 +56,7 @@ def preprocess_data(
Returns
-------
Preprocessed adata.
"""
if min_max_scale:
scaler = MinMaxScaler()

0 comments on commit c0a8ee5

Please sign in to comment.