Cleanup terminology and unused code (#4)

* first pass of making terminology consistent, removing unnecessary and unused code, adding docstrings
* tests passing
* flake
* typo
* address comments
Showing 9 changed files with 346 additions and 442 deletions.
@@ -0,0 +1,13 @@
@article{Boyeau2022mrvi,
    abstract = {Contemporary single-cell omics technologies have enabled complex experimental designs incorporating hundreds of samples accompanied by detailed information on sample-level conditions. Current approaches for analyzing condition-level heterogeneity in these experiments often rely on a simplification of the data such as an aggregation at the cell-type or cell-state-neighborhood level. Here we present MrVI, a deep generative model that provides sample-sample comparisons at a single-cell resolution, permitting the discovery of subtle sample-specific effects across cell populations. Additionally, the output of MrVI can be used to quantify the association between sample-level metadata and cell state variation. We benchmarked MrVI against conventional meta-analysis procedures on two synthetic datasets and one real dataset with a well-controlled experimental structure. This work introduces a novel approach to understanding sample-level heterogeneity while leveraging the full resolution of single-cell sequencing data. Competing Interest Statement: N.Y. is an advisor and/or has equity in Cellarity, Celsius Therapeutics, and Rheos Medicine.},
    author = {Boyeau, Pierre and Hong, Justin and Gayoso, Adam and Jordan, Michael I. and Azizi, Elham and Yosef, Nir},
    doi = {10.1101/2022.10.04.510898},
    elocation-id = {2022.10.04.510898},
    eprint = {https://www.biorxiv.org/content/early/2022/10/06/2022.10.04.510898.full.pdf},
    journal = {bioRxiv},
    publisher = {Cold Spring Harbor Laboratory},
    title = {Deep generative modeling for quantifying sample-level heterogeneity in single-cell omics},
    url = {https://www.biorxiv.org/content/early/2022/10/06/2022.10.04.510898},
    year = {2022},
    bdsk-url-1 = {https://www.biorxiv.org/content/early/2022/10/06/2022.10.04.510898},
    bdsk-url-2 = {https://doi.org/10.1101/2022.10.04.510898}}
@@ -1,2 +1,6 @@
References
==========

**Deep generative modeling for quantifying sample-level heterogeneity in single-cell omics**
Pierre Boyeau*, Justin Hong*, Adam Gayoso, Michael I. Jordan, Elham Azizi, Nir Yosef
bioRxiv 2022. `Link <https://doi.org/10.1101/2022.10.04.510898>`_.
@@ -0,0 +1,181 @@
import torch
import torch.nn as nn
from scvi.distributions import NegativeBinomial
from scvi.nn import one_hot

from ._utils import ResnetFC


class ExpActivation(nn.Module):
    """Exponential activation wrapped as a module."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.exp(x)


class DecoderZX(nn.Module):
    """Parameterizes the counts likelihood for the data given the latent variables."""

    def __init__(
        self,
        n_in,
        n_out,
        n_nuisance,
        linear_decoder,
        n_hidden=128,
        activation="softmax",
    ):
        super().__init__()
        if activation == "softmax":
            activation_ = nn.Softmax(-1)
        elif activation == "softplus":
            activation_ = nn.Softplus()
        elif activation == "exp":
            activation_ = ExpActivation()
        elif activation == "sigmoid":
            activation_ = nn.Sigmoid()
        else:
            raise ValueError(
                "activation must be one of 'softmax', 'softplus', 'exp', or 'sigmoid'"
            )
        self.linear_decoder = linear_decoder
        self.n_nuisance = n_nuisance
        self.n_latent = n_in - n_nuisance
        if linear_decoder:
            self.amat = nn.Linear(self.n_latent, n_out, bias=False)
            self.amat_site = nn.Parameter(
                torch.randn(self.n_nuisance, self.n_latent, n_out)
            )
            self.offsets = nn.Parameter(torch.randn(self.n_nuisance, n_out))
            self.dropout_ = nn.Dropout(0.1)
            self.activation_ = activation_
        else:
            self.px_mean = ResnetFC(
                n_in=n_in,
                n_out=n_out,
                n_hidden=n_hidden,
                activation=activation_,
            )
        # The inverse dispersion is used by `forward` in both branches, so it is
        # defined outside the if/else.
        self.px_r = nn.Parameter(torch.randn(n_out))

    def forward(self, z, size_factor):
        if self.linear_decoder:
            # Split the input into the latent part and the trailing one-hot
            # nuisance part.
            nuisance_oh = z[..., -self.n_nuisance :]
            z0 = z[..., : -self.n_nuisance]
            x1 = self.amat(z0)

            # Nuisance-specific linear map, applied to a detached copy of z0 so
            # that gradients only flow through the shared term x1.
            nuisance_ids = torch.argmax(nuisance_oh, -1)
            As = self.amat_site[nuisance_ids]
            z0_detach = self.dropout_(z0.detach())[..., None]
            x2 = (As * z0_detach).sum(-2)
            offsets = self.offsets[nuisance_ids]
            mu = x1 + x2 + offsets
            mu = self.activation_(mu)
        else:
            mu = self.px_mean(z)
        mu = mu * size_factor
        return NegativeBinomial(mu=mu, theta=self.px_r.exp())


class LinearDecoderUZ(nn.Module):
    """Sample-conditioned linear decoder mapping ``u`` to ``z``."""

    def __init__(
        self,
        n_latent,
        n_sample,
        n_out,
        scaler=False,
        scaler_n_hidden=32,
    ):
        super().__init__()
        self.n_latent = n_latent
        self.n_sample = n_sample
        self.n_out = n_out

        self.amat_sample = nn.Parameter(torch.randn(n_sample, self.n_latent, n_out))
        self.offsets = nn.Parameter(torch.randn(n_sample, n_out))

        self.scaler = None
        if scaler:
            self.scaler = nn.Sequential(
                nn.Linear(n_latent + n_sample, scaler_n_hidden),
                nn.LayerNorm(scaler_n_hidden),
                nn.ReLU(),
                nn.Linear(scaler_n_hidden, 1),
                nn.Sigmoid(),
            )

    def forward(self, u, sample_id):
        sample_id_ = sample_id.long().squeeze()
        As = self.amat_sample[sample_id_]

        # Sample-specific affine shift, computed from a detached copy of u.
        u_detach = u.detach()[..., None]
        z2 = (As * u_detach).sum(-2)
        offsets = self.offsets[sample_id_]
        delta = z2 + offsets
        if self.scaler is not None:
            # Optionally gate the shift with a learned scalar in (0, 1).
            sample_oh = one_hot(sample_id, self.n_sample)
            if u.ndim != sample_oh.ndim:
                sample_oh = sample_oh[None].expand(u.shape[0], *sample_oh.shape)
            inputs = torch.cat([u.detach(), sample_oh], -1)
            delta = delta * self.scaler(inputs)
        return u + delta


class DecoderUZ(nn.Module):
    """MLP decoder mapping ``u`` and a sample embedding to ``z`` as a gated residual update."""

    def __init__(
        self,
        n_latent,
        n_latent_sample,
        n_out,
        dropout_rate=0.0,
        n_layers=1,
        n_hidden=128,
    ):
        super().__init__()
        self.n_latent = n_latent
        self.n_latent_sample = n_latent_sample
        self.n_in = n_latent + n_latent_sample
        self.n_out = n_out

        arch_mod = self.construct_arch(self.n_in, n_hidden, n_layers, dropout_rate) + [
            nn.Linear(n_hidden, self.n_out, bias=False)
        ]
        self.mod = nn.Sequential(*arch_mod)

        arch_scaler = self.construct_arch(
            self.n_latent, n_hidden, n_layers, dropout_rate
        ) + [nn.Linear(n_hidden, 1)]
        self.scaler = nn.Sequential(*arch_scaler)
        self.scaler.append(nn.Sigmoid())

    @staticmethod
    def construct_arch(n_inputs, n_hidden, n_layers, dropout_rate):
        """Initializes MLP architecture."""
        block_inputs = [
            nn.Linear(n_inputs, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.Dropout(p=dropout_rate),
            nn.ReLU(),
        ]

        block_inner = n_layers * [
            nn.Linear(n_hidden, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
        ]
        return block_inputs + block_inner

    def forward(self, u):
        u_ = u.clone()
        if u_.dim() == 3:
            # Flatten (n_samples, n_cells, n_features) so BatchNorm1d sees 2D input.
            n_samples, n_cells, n_features = u_.shape
            u0_ = u_[:, :, : self.n_latent].reshape(-1, self.n_latent)
            u_ = u_.reshape(-1, n_features)
            pred_ = self.mod(u_).reshape(n_samples, n_cells, -1)
            scaler_ = self.scaler(u0_).reshape(n_samples, n_cells, -1)
        else:
            pred_ = self.mod(u)
            scaler_ = self.scaler(u[:, : self.n_latent])
        # Residual update of the first n_latent dims, gated by a sigmoid scaler.
        mean = u[..., : self.n_latent] + scaler_ * pred_
        return mean
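
For orientation, here is a minimal shape-check sketch of the three decoders above. All sizes are illustrative assumptions, not values from the commit, and it assumes the classes are importable as defined:

import torch

# Illustrative sizes (assumptions): 10 latent dims, 3 nuisance categories,
# 50 genes, 4 samples, a 5-dim sample embedding, a minibatch of 8 cells.
n_latent, n_nuisance, n_genes, n_sample, n_cells = 10, 3, 50, 4, 8

# DecoderZX consumes [z, one-hot nuisance] and returns a NegativeBinomial.
px_decoder = DecoderZX(
    n_in=n_latent + n_nuisance,
    n_out=n_genes,
    n_nuisance=n_nuisance,
    linear_decoder=True,
)
z = torch.randn(n_cells, n_latent)
nuisance_oh = torch.nn.functional.one_hot(
    torch.randint(n_nuisance, (n_cells,)), n_nuisance
).float()
size_factor = torch.rand(n_cells, 1) * 1e3
px = px_decoder(torch.cat([z, nuisance_oh], -1), size_factor)
assert px.mean.shape == (n_cells, n_genes)

# LinearDecoderUZ shifts u by a sample-specific linear map of u.
qz_linear = LinearDecoderUZ(n_latent=n_latent, n_sample=n_sample, n_out=n_latent)
u = torch.randn(n_cells, n_latent)
sample_id = torch.randint(n_sample, (n_cells, 1))
assert qz_linear(u, sample_id).shape == (n_cells, n_latent)

# DecoderUZ takes [u, sample embedding] and returns a gated residual update of u.
qz_mlp = DecoderUZ(n_latent=n_latent, n_latent_sample=5, n_out=n_latent)
u_cat = torch.randn(n_cells, n_latent + 5)
assert qz_mlp(u_cat).shape == (n_cells, n_latent)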
@@ -0,0 +1,9 @@
from typing import NamedTuple


class _MRVI_REGISTRY_KEYS_NT(NamedTuple):
    SAMPLE_KEY: str = "sample"
    CATEGORICAL_NUISANCE_KEYS: str = "categorical_nuisance_keys"


MRVI_REGISTRY_KEYS = _MRVI_REGISTRY_KEYS_NT()
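
Because the constant is a NamedTuple instance, it acts as a frozen namespace of string keys rather than a mutable dict; a quick sketch of what that buys (the assertions follow directly from the defaults above):

assert MRVI_REGISTRY_KEYS.SAMPLE_KEY == "sample"
assert MRVI_REGISTRY_KEYS.CATEGORICAL_NUISANCE_KEYS == "categorical_nuisance_keys"

# Immutable and iterable, unlike module-level string constants:
assert list(MRVI_REGISTRY_KEYS) == ["sample", "categorical_nuisance_keys"]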