Include MoG and ConditionalNorms.
canergen committed Sep 22, 2023
1 parent 69d6b25 commit 0fa645c
Showing 3 changed files with 120 additions and 21 deletions.
66 changes: 56 additions & 10 deletions src/embedding_scvi/_components.py
@@ -8,6 +8,38 @@
from torch import nn


class ConditionalBatchNorm2d(nn.Module):
    """Class-conditional batch norm: per-class scale and shift applied on top of
    an affine-free ``nn.BatchNorm1d``. Note: despite the ``2d`` in the name, this
    operates on (n_obs, n_features) inputs."""

    def __init__(self, num_features, num_classes, momentum, eps):
        super().__init__()
        self.num_features = num_features
        self.bn = nn.BatchNorm1d(self.num_features, momentum=momentum, eps=eps, affine=False)
        self.embed = nn.Embedding(num_classes, self.num_features * 2)
        self.embed.weight.data[:, : self.num_features].normal_(1, 0.02)  # initialise scale at N(1, 0.02)
        self.embed.weight.data[:, self.num_features :].zero_()  # initialise bias at 0

    def forward(self, x, y):
        out = self.bn(x)
        gamma, beta = self.embed(y.long().ravel()).chunk(2, 1)
        out = gamma.view(-1, self.num_features) * out + beta.view(-1, self.num_features)
        return out


class ConditionalLayerNorm(nn.Module):
    """Class-conditional layer norm: per-class scale and shift applied on top of
    an affine-free ``nn.LayerNorm``."""

    def __init__(self, num_features, num_classes):
        super().__init__()
        self.num_features = num_features
        self.ln = nn.LayerNorm(self.num_features, elementwise_affine=False)
        self.embed = nn.Embedding(num_classes, self.num_features * 2)
        self.embed.weight.data[:, : self.num_features].normal_(1, 0.02)  # initialise scale at N(1, 0.02)
        self.embed.weight.data[:, self.num_features :].zero_()  # initialise bias at 0

    def forward(self, x, y):
        out = self.ln(x)
        gamma, beta = self.embed(y.long().ravel()).chunk(2, 1)
        out = gamma.view(-1, self.num_features) * out + beta.view(-1, self.num_features)
        return out

class MLPBlock(nn.Module):
"""Multi-layer perceptron block.
@@ -45,7 +77,9 @@ def __init__(
        self,
        n_in: int,
        n_out: int,
        cat_dim: int | None = None,
        bias: bool = True,
        conditional: bool = False,
        norm: Literal["batch", "layer"] | None = None,
        norm_kwargs: dict | None = None,
        activation: Literal["relu", "leaky_relu", "softmax", "softplus"] | None = None,
@@ -62,16 +96,28 @@ def __init__(
        self.dropout = nn.Identity()
        self.residual = residual

if norm == "batch":
self.norm = nn.BatchNorm1d(n_out, **self.norm_kwargs)
elif norm == "layer":
self.norm = nn.LayerNorm(n_out, **self.norm_kwargs)
elif norm is not None:
raise InvalidParameterError(
param="norm",
value=norm,
valid=["batch", "layer", None],
)
        if conditional:
            if norm == "batch":
                self.norm = ConditionalBatchNorm2d(n_out, cat_dim, momentum=0.01, eps=0.001)
            elif norm == "layer":
                self.norm = ConditionalLayerNorm(n_out, cat_dim)
            elif norm is not None:
                raise InvalidParameterError(
                    param="norm",
                    value=norm,
                    valid=["batch", "layer", None],
                )
        else:
            if norm == "batch":
                self.norm = nn.BatchNorm1d(n_out, **self.norm_kwargs)
            elif norm == "layer":
                self.norm = nn.LayerNorm(n_out, **self.norm_kwargs)
            elif norm is not None:
                raise InvalidParameterError(
                    param="norm",
                    value=norm,
                    valid=["batch", "layer", None],
                )

        if activation == "relu":
            self.activation = nn.ReLU(**self.activation_kwargs)
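A quick sanity check of the new conditional norms (a minimal sketch; the toy sizes and random inputs are illustrative, not from this commit):

import torch

n_obs, n_features, n_classes = 8, 16, 3  # toy sizes (assumption)
norm = ConditionalLayerNorm(num_features=n_features, num_classes=n_classes)
x = torch.randn(n_obs, n_features)
y = torch.randint(0, n_classes, (n_obs, 1))  # per-cell integer labels
out = norm(x, y)  # y is cast to long and flattened internally
assert out.shape == (n_obs, n_features)

Each class label indexes an embedding row holding a per-class (gamma, beta) pair, which rescales and shifts the affine-free normalized activations.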
3 changes: 3 additions & 0 deletions src/embedding_scvi/_model.py
@@ -38,6 +38,7 @@ def __init__(self, adata: AnnData, **kwargs):
        )
        self.module = EmbeddingVAE(
            n_vars=self.summary_stats.n_vars,
            n_labels=self.summary_stats.n_labels,
            categorical_covariates=categorical_covariates,
            **kwargs,
        )
@@ -63,6 +64,7 @@ def setup_anndata(
        cls,
        adata: AnnData,
        layer: str | None = None,
        labels_key: str | None = None,
        categorical_covariate_keys: list[str] | None = None,
        **kwargs,
    ):
@@ -77,6 +79,7 @@
        setup_method_args = cls._get_setup_method_args(**locals())
        anndata_fields = [
            fields.LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
            fields.CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
            ExtendableCategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys),
        ]
        adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
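With the new labels field registered, a typical setup could look like the sketch below. The model class name `EmbeddingSCVI` and the obs column names are assumptions; only the `labels_key` argument and the kwarg forwarding to `EmbeddingVAE` come from this commit:

import anndata as ad

adata = ad.read_h5ad("data.h5ad")  # placeholder path
EmbeddingSCVI.setup_anndata(       # hypothetical wrapper class name
    adata,
    labels_key="cell_type",                # populates REGISTRY_KEYS.LABELS_KEY
    categorical_covariate_keys=["donor"],  # assumed obs column
)
model = EmbeddingSCVI(adata, prior="mog_celltype")  # extra kwargs reach EmbeddingVAE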
72 changes: 61 additions & 11 deletions src/embedding_scvi/_module.py
@@ -39,10 +39,13 @@ def __init__(
        self,
        n_vars: int,
        n_latent: int = 25,
        n_labels: int | None = None,
        categorical_covariates: list[int] | None = None,
        likelihood: str = "zinb",
        encoder_kwargs: dict | None = None,
        decoder_kwargs: dict | None = None,
        prior: str | None = None,
        mixture_components: int = 50,
    ):
        super().__init__()

@@ -53,6 +56,30 @@ def __init__(
        self.encoder_kwargs = encoder_kwargs or {}
        self.decoder_kwargs = decoder_kwargs or {}

        self.prior = prior
        self.n_labels = n_labels  # referenced again in `inference` for the label bias
        if self.prior == "mog":
            # Scales are stored in log space and exponentiated at use.
            self.register_buffer("u_prior_logits", torch.ones([mixture_components]))
            self.register_buffer("u_prior_means", torch.randn([n_latent, mixture_components]))
            self.register_buffer("u_prior_scales", torch.zeros([n_latent, mixture_components]))
        elif self.prior == "mog_celltype":
            # One mixture component per cell-type label.
            self.register_buffer("u_prior_logits", torch.ones([n_labels]))
            self.register_buffer("u_prior_means", torch.randn([n_latent, n_labels]))
            self.register_buffer("u_prior_scales", torch.zeros([n_latent, n_labels]))

        self.covariates_encoder = nn.Identity()
        if self.categorical_covariates is not None:
            self.covariates_encoder = ExtendableEmbeddingList(
                num_embeddings=self.categorical_covariates,
                embedding_dim=self.n_latent,
            )

        encoder_dist_params = likelihood_to_dist_params("normal")
        _encoder_kwargs = {
            "n_hidden": 256,
@@ -62,6 +89,7 @@ def __init__(
"activation": "gelu",
"dropout_rate": 0.1,
"residual": True,
"cat_dim": self.categorical_covariates.num_embeddings,
}
        _encoder_kwargs.update(self.encoder_kwargs)
        self.encoder = MultiOutputMLP(
@@ -81,7 +109,9 @@ def __init__(
"activation": "gelu",
"dropout_rate": None,
"residual": True,
"n_out_params": self.categorical_covariates.num_embeddings,
}

_decoder_kwargs.update(self.decoder_kwargs)
        self.decoder = MultiOutputMLP(
            n_in=self.n_latent,
@@ -91,13 +121,6 @@ def __init__(
            **_decoder_kwargs,
        )

    def get_covariate_embeddings(
        self,
        covariate_indexes: list[int] | int | None,
@@ -113,16 +136,19 @@

    def _get_inference_input(self, tensors: dict[str, torch.Tensor]) -> dict:
        x = tensors[REGISTRY_KEYS.X_KEY]
        y = tensors[REGISTRY_KEYS.LABELS_KEY]
        covariates = tensors.get(REGISTRY_KEYS.CAT_COVS_KEY, None)
        return {
            REGISTRY_KEYS.X_KEY: x,
            REGISTRY_KEYS.LABELS_KEY: y,
            REGISTRY_KEYS.CAT_COVS_KEY: covariates,
        }

    @auto_move_data
    def inference(
        self,
        X: torch.Tensor,
        y: torch.Tensor | None = None,
        extra_categorical_covs: torch.Tensor | None = None,
        subset_categorical_covs: int | list[int] | None = None,
    ):
@@ -131,7 +157,23 @@

        posterior_loc, posterior_scale = self.encoder(X)
        posterior = dist.Normal(posterior_loc, posterior_scale + 1e-9)

        if self.prior == "mog":
            cats = dist.Categorical(logits=self.u_prior_logits)
            # Scales live in log space; the mixture is over the last (component)
            # axis, independently for each latent dimension.
            normal_dists = dist.Normal(self.u_prior_means, torch.exp(self.u_prior_scales))
            prior = dist.MixtureSameFamily(cats, normal_dists)
        elif self.prior == "mog_celltype":
            logits = self.u_prior_logits
            if self.n_labels >= 2:
                # Bias each cell's mixture logits towards its own label's component.
                # (n_obs, n_labels) -> (n_obs, 1, n_labels) so the mixture batch shape
                # broadcasts against the (n_latent, n_labels) component batch shape.
                one_hot = torch.nn.functional.one_hot(y.long().ravel(), self.n_labels)
                logits = logits + 10.0 * one_hot.unsqueeze(-2)
            cats = dist.Categorical(logits=logits)
            normal_dists = dist.Normal(self.u_prior_means, torch.exp(self.u_prior_scales))
            prior = dist.MixtureSameFamily(cats, normal_dists)
        else:
            prior = dist.Normal(torch.zeros_like(posterior_loc), torch.ones_like(posterior_scale))

        z = posterior.rsample()

        covariates_z = self.covariates_encoder(
@@ -216,10 +258,18 @@ def loss(
        X = tensors[REGISTRY_KEYS.X_KEY]
        posterior = inference_outputs[TENSORS_KEYS.QZ_KEY]
        prior = inference_outputs[TENSORS_KEYS.PZ_KEY]

        if self.prior in ("mog", "mog_celltype"):
            # No closed-form KL against a mixture prior, so use a Monte Carlo
            # estimate of KL(q || p) = E_q[log q(u) - log p(u)] with 10 samples.
            u = posterior.rsample(sample_shape=(10,))
            # (10, n_obs, n_latent) -> (10, n_obs) -> (n_obs,)
            kl_div = (posterior.log_prob(u) - prior.log_prob(u)).sum(-1).mean(0)
        else:
            # (n_obs, n_latent) -> (n_obs,)
            kl_div = dist.kl_divergence(posterior, prior).sum(dim=-1)

        likelihood = generative_outputs[TENSORS_KEYS.PX_KEY]
        weighted_kl_div = kl_weight * kl_div
        # (n_obs, n_vars) -> (n_obs,)
        reconstruction_loss = -likelihood.log_prob(X).sum(dim=-1)
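For reference, a self-contained sketch of the Monte Carlo KL term used above, with toy sizes, showing how the buffer shapes broadcast through MixtureSameFamily (all names and sizes here are illustrative, not from the commit):

import torch
import torch.distributions as dist

n_obs, n_latent, n_components, n_mc = 4, 10, 50, 10  # toy sizes (assumption)

posterior = dist.Normal(torch.randn(n_obs, n_latent), torch.ones(n_obs, n_latent))

# Mirror the 'mog' buffers: (n_latent, n_components) means and log-scales;
# mixing over the last axis gives one 1-D mixture per latent dimension.
cats = dist.Categorical(logits=torch.ones(n_components))
comps = dist.Normal(torch.randn(n_latent, n_components), torch.exp(torch.zeros(n_latent, n_components)))
prior = dist.MixtureSameFamily(cats, comps)

u = posterior.rsample(sample_shape=(n_mc,))  # (n_mc, n_obs, n_latent)
kl = (posterior.log_prob(u) - prior.log_prob(u)).sum(-1).mean(0)  # (n_obs,)
print(kl.shape)  # torch.Size([4])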
