diff --git a/src/embedding_scvi/_components.py b/src/embedding_scvi/_components.py
index 92d38cf..a553656 100644
--- a/src/embedding_scvi/_components.py
+++ b/src/embedding_scvi/_components.py
@@ -8,6 +8,38 @@
 from torch import nn
 
 
+class ConditionalBatchNorm2d(nn.Module):
+    def __init__(self, num_features, num_classes, momentum, eps):
+        super().__init__()
+        self.num_features = num_features
+        self.bn = nn.BatchNorm1d(self.num_features, momentum=momentum, eps=eps, affine=False)
+        self.embed = nn.Embedding(num_classes, self.num_features * 2)
+        self.embed.weight.data[:, :self.num_features].normal_(1, 0.02)  # Initialise scale at N(1, 0.02)
+        self.embed.weight.data[:, self.num_features:].zero_()  # Initialise bias at 0
+
+    def forward(self, x, y):
+        out = self.bn(x)
+        gamma, beta = self.embed(y.long().ravel()).chunk(2, 1)
+        out = gamma.view(-1, self.num_features) * out + beta.view(-1, self.num_features)
+
+        return out
+
+class ConditionalLayerNorm(nn.Module):
+    def __init__(self, num_features, num_classes):
+        super().__init__()
+        self.num_features = num_features
+        self.ln = nn.LayerNorm(self.num_features, elementwise_affine=False)
+        self.embed = nn.Embedding(num_classes, self.num_features * 2)
+        self.embed.weight.data[:, :self.num_features].normal_(1, 0.02)  # Initialise scale at N(1, 0.02)
+        self.embed.weight.data[:, self.num_features:].zero_()  # Initialise bias at 0
+
+    def forward(self, x, y):
+        out = self.ln(x)
+        gamma, beta = self.embed(y.long().ravel()).chunk(2, 1)
+        out = gamma.view(-1, self.num_features) * out + beta.view(-1, self.num_features)
+
+        return out
+
 class MLPBlock(nn.Module):
     """Multi-layer perceptron block.
@@ -45,7 +77,9 @@ def __init__(
         self,
         n_in: int,
         n_out: int,
+        cat_dim: int | None = None,
         bias: bool = True,
+        conditional: bool = False,
         norm: Literal["batch", "layer"] | None = None,
         norm_kwargs: dict | None = None,
         activation: Literal["relu", "leaky_relu", "softmax", "softplus"] | None = None,
@@ -62,16 +96,28 @@ def __init__(
         self.dropout = nn.Identity()
         self.residual = residual
 
-        if norm == "batch":
-            self.norm = nn.BatchNorm1d(n_out, **self.norm_kwargs)
-        elif norm == "layer":
-            self.norm = nn.LayerNorm(n_out, **self.norm_kwargs)
-        elif norm is not None:
-            raise InvalidParameterError(
-                param="norm",
-                value=norm,
-                valid=["batch", "layer", None],
-            )
+        if conditional:
+            if norm == "batch":
+                self.norm = ConditionalBatchNorm2d(n_out, cat_dim, momentum=0.01, eps=0.001)
+            elif norm == "layer":
+                self.norm = ConditionalLayerNorm(n_out, cat_dim)
+            elif norm is not None:
+                raise InvalidParameterError(
+                    param="norm",
+                    value=norm,
+                    valid=["batch", "layer", None],
+                )
+        else:
+            if norm == "batch":
+                self.norm = nn.BatchNorm1d(n_out, **self.norm_kwargs)
+            elif norm == "layer":
+                self.norm = nn.LayerNorm(n_out, **self.norm_kwargs)
+            elif norm is not None:
+                raise InvalidParameterError(
+                    param="norm",
+                    value=norm,
+                    valid=["batch", "layer", None],
+                )
 
         if activation == "relu":
             self.activation = nn.ReLU(**self.activation_kwargs)
diff --git a/src/embedding_scvi/_model.py b/src/embedding_scvi/_model.py
index a5dd464..ee50500 100644
--- a/src/embedding_scvi/_model.py
+++ b/src/embedding_scvi/_model.py
@@ -38,6 +38,7 @@ def __init__(self, adata: AnnData, **kwargs):
         )
         self.module = EmbeddingVAE(
             n_vars=self.summary_stats.n_vars,
+            n_labels=self.summary_stats.n_labels,
             categorical_covariates=categorical_covariates,
             **kwargs,
         )
@@ -63,6 +64,7 @@ def setup_anndata(
         cls,
         adata: AnnData,
         layer: str | None = None,
+        labels_key: str | None = None,
         categorical_covariate_keys: list[str] | None = None,
         **kwargs,
     ):
@@ -77,6 +79,7 @@ def setup_anndata(
         setup_method_args = cls._get_setup_method_args(**locals())
         anndata_fields = [
             fields.LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
+            fields.CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
             ExtendableCategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys),
         ]
         adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
diff --git a/src/embedding_scvi/_module.py b/src/embedding_scvi/_module.py
index 88f8ae5..20c9697 100644
--- a/src/embedding_scvi/_module.py
+++ b/src/embedding_scvi/_module.py
@@ -39,10 +39,13 @@ def __init__(
         self,
         n_vars: int,
         n_latent: int = 25,
+        n_labels: int | None = None,
         categorical_covariates: list[int] | None = None,
         likelihood: str = "zinb",
         encoder_kwargs: dict | None = None,
         decoder_kwargs: dict | None = None,
+        prior: str | None = None,
+        mixture_components: int = 50,
     ):
         super().__init__()
@@ -53,6 +56,30 @@ def __init__(
         self.encoder_kwargs = encoder_kwargs or {}
         self.decoder_kwargs = decoder_kwargs or {}
 
+        self.n_labels = n_labels
+        self.prior = prior
+        if self.prior == "mog":
+            self.register_buffer(
+                "u_prior_logits", torch.ones([mixture_components]))
+            self.register_buffer(
+                "u_prior_means", torch.randn([n_latent, mixture_components]))
+            self.register_buffer(
+                "u_prior_scales", torch.zeros([n_latent, mixture_components]))
+        elif self.prior == "mog_celltype":
+            self.register_buffer(
+                "u_prior_logits", torch.ones([n_labels]))
+            self.register_buffer(
+                "u_prior_means", torch.randn([n_latent, n_labels]))
+            self.register_buffer(
+                "u_prior_scales", torch.zeros([n_latent, n_labels]))
+
+        self.covariates_encoder = nn.Identity()
+        if self.categorical_covariates is not None:
+            self.covariates_encoder = ExtendableEmbeddingList(
+                num_embeddings=self.categorical_covariates,
+                embedding_dim=self.n_latent,
+            )
+
         encoder_dist_params = likelihood_to_dist_params("normal")
         _encoder_kwargs = {
             "n_hidden": 256,
@@ -62,6 +89,7 @@ def __init__(
             "activation": "gelu",
             "dropout_rate": 0.1,
             "residual": True,
+            "cat_dim": self.categorical_covariates.num_embeddings,
         }
         _encoder_kwargs.update(self.encoder_kwargs)
         self.encoder = MultiOutputMLP(
@@ -81,7 +109,9 @@ def __init__(
             "activation": "gelu",
             "dropout_rate": None,
             "residual": True,
+            "n_out_params": self.categorical_covariates.num_embeddings,
         }
+        _decoder_kwargs.update(self.decoder_kwargs)
         self.decoder = MultiOutputMLP(
             n_in=self.n_latent,
@@ -91,13 +121,6 @@ def __init__(
             **_decoder_kwargs,
         )
 
-        self.covariates_encoder = nn.Identity()
-        if self.categorical_covariates is not None:
-            self.covariates_encoder = ExtendableEmbeddingList(
-                num_embeddings=self.categorical_covariates,
-                embedding_dim=self.n_latent,
-            )
-
     def get_covariate_embeddings(
         self,
         covariate_indexes: list[int] | int | None,
@@ -113,9 +136,11 @@ def get_covariate_embeddings(
 
     def _get_inference_input(self, tensors: dict[str, torch.Tensor]) -> dict:
         x = tensors[REGISTRY_KEYS.X_KEY]
+        y = tensors[REGISTRY_KEYS.LABELS_KEY]
         covariates = tensors.get(REGISTRY_KEYS.CAT_COVS_KEY, None)
         return {
             REGISTRY_KEYS.X_KEY: x,
+            REGISTRY_KEYS.LABELS_KEY: y,
             REGISTRY_KEYS.CAT_COVS_KEY: covariates,
         }
@@ -123,6 +148,7 @@ def _get_inference_input(self, tensors: dict[str, torch.Tensor]) -> dict:
     def inference(
         self,
         X: torch.Tensor,
+        y: torch.Tensor | None = None,
         extra_categorical_covs: torch.Tensor | None = None,
         subset_categorical_covs: int | list[int] | None = None,
     ):
@@ -131,7 +157,23 @@ def inference(
 
         posterior_loc, posterior_scale = self.encoder(X)
         posterior = dist.Normal(posterior_loc, posterior_scale + 1e-9)
-        prior = dist.Normal(torch.zeros_like(posterior_loc), torch.ones_like(posterior_scale))
+
+        if self.prior == "mog":
+            cats = dist.Categorical(logits=self.u_prior_logits)
+            normal_dists = dist.Normal(
+                self.u_prior_means,
+                torch.exp(self.u_prior_scales))
+            prior = dist.MixtureSameFamily(cats, normal_dists)
+        elif self.prior == "mog_celltype":
+            # Bias the mixture logits towards the observed cell-type component
+            label_bias = 10.0 * torch.nn.functional.one_hot(y.long(), self.n_labels) if self.n_labels >= 2 else 0.0
+            cats = dist.Categorical(logits=self.u_prior_logits + label_bias)
+            normal_dists = dist.Normal(
+                self.u_prior_means,
+                torch.exp(self.u_prior_scales))
+            prior = dist.MixtureSameFamily(cats, normal_dists)
+        else:
+            prior = dist.Normal(torch.zeros_like(posterior_loc), torch.ones_like(posterior_scale))
+
         z = posterior.rsample()
 
         covariates_z = self.covariates_encoder(
@@ -216,10 +258,18 @@ def loss(
         X = tensors[REGISTRY_KEYS.X_KEY]
         posterior = inference_outputs[TENSORS_KEYS.QZ_KEY]
         prior = inference_outputs[TENSORS_KEYS.PZ_KEY]
-        likelihood = generative_outputs[TENSORS_KEYS.PX_KEY]
-        # (n_obs, n_latent) -> (n_obs,)
-        kl_div = dist.kl_divergence(posterior, prior).sum(dim=-1)
+        if self.prior == "mog" or self.prior == "mog_celltype":
+            # Monte Carlo estimate of KL(q || p): no closed form for a mixture prior
+            u = posterior.rsample(sample_shape=(10,))
+            # (n_samples, n_obs, n_latent) -> (n_obs,)
+            kl_z = posterior.log_prob(u) - prior.log_prob(u)
+            kl_div = kl_z.sum(-1).mean(0)
+        else:
+            # (n_obs, n_latent) -> (n_obs,)
+            kl_div = dist.kl_divergence(posterior, prior).sum(dim=-1)
+
+        likelihood = generative_outputs[TENSORS_KEYS.PX_KEY]
         weighted_kl_div = kl_weight * kl_div
         # (n_obs, n_vars) -> (n_obs,)
         reconstruction_loss = -likelihood.log_prob(X).sum(dim=-1)
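For context, here is a minimal standalone sketch of the class-conditional normalisation mechanism that `ConditionalBatchNorm2d` / `ConditionalLayerNorm` implement: the affine parameters of the norm are replaced by a per-category scale and shift looked up from an embedding table. This is plain PyTorch with made-up tensor names, not part of the patch.

```python
import torch
from torch import nn

n_obs, n_features, n_classes = 8, 16, 3
x = torch.randn(n_obs, n_features)
y = torch.randint(0, n_classes, (n_obs, 1))  # one categorical covariate per cell

ln = nn.LayerNorm(n_features, elementwise_affine=False)  # normalisation without its own affine params
embed = nn.Embedding(n_classes, 2 * n_features)          # per-class (gamma, beta)
embed.weight.data[:, :n_features].normal_(1, 0.02)       # gamma ~ N(1, 0.02)
embed.weight.data[:, n_features:].zero_()                # beta = 0

gamma, beta = embed(y.long().ravel()).chunk(2, dim=1)    # (n_obs, n_features) each
out = gamma * ln(x) + beta                                # class-conditional scale and shift
print(out.shape)  # torch.Size([8, 16])
```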
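Likewise, a minimal sketch of the mixture-of-Gaussians prior and the Monte Carlo KL term used in the loss: the mixture is factorized per latent dimension, and the KL follows the log q(z) − log p(z) sign convention of the corrected branch above. Tensor names and sizes here are illustrative, not the module's actual buffers.

```python
import torch
import torch.distributions as dist

torch.manual_seed(0)
n_obs, n_latent, n_components = 4, 10, 5

# Per-dimension mixture-of-Gaussians prior (logits, means, log-scales per component)
logits = torch.ones(n_components)
means = torch.randn(n_latent, n_components)
log_scales = torch.zeros(n_latent, n_components)
prior = dist.MixtureSameFamily(
    dist.Categorical(logits=logits),
    dist.Normal(means, torch.exp(log_scales)),
)  # batch_shape = (n_latent,), event_shape = ()

# Amortized posterior q(z | x), as produced by the encoder
qz_loc = torch.randn(n_obs, n_latent)
qz_scale = torch.rand(n_obs, n_latent) + 0.1
posterior = dist.Normal(qz_loc, qz_scale)

# Monte Carlo estimate of KL(q || p) = E_q[log q(z) - log p(z)]
u = posterior.rsample(sample_shape=(10,))                          # (10, n_obs, n_latent)
kl = (posterior.log_prob(u) - prior.log_prob(u)).sum(-1).mean(0)   # (n_obs,)
print(kl.shape)  # torch.Size([4])
```

Because the mixture is defined per latent dimension, `prior.log_prob` broadcasts against samples of shape `(n_samples, n_obs, n_latent)`, which is why the KL term is summed over the last axis and averaged over the sample axis.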