From e4887393f1a9b34462301201d337dfeb685a8b39 Mon Sep 17 00:00:00 2001
From: Lucas Diedrich
Date: Wed, 22 May 2024 11:22:40 -0400
Subject: [PATCH] Independent logit for covariate-informed and
 gene-expression-informed dimensions

---
 sccoral/module/_module.py | 9 +++------
 sccoral/nn/_components.py | 2 +-
 2 files changed, 4 insertions(+), 7 deletions(-)

diff --git a/sccoral/module/_module.py b/sccoral/module/_module.py
index 3154217..0127576 100644
--- a/sccoral/module/_module.py
+++ b/sccoral/module/_module.py
@@ -171,7 +171,7 @@ def __init__(
             n_layers=n_layers,
             n_hidden=n_hidden,
             dropout_rate=dropout_rate,
-            distribution="normal",
+            distribution=latent_distribution,
             use_batch_norm=self.use_batch_norm_encoder,
             use_layer_norm=self.use_layer_norm_encoder,
             return_dist=False,
@@ -202,7 +202,7 @@ def __init__(
 
                 name = f"encoder_{cat_name}"
                 model = LinearEncoder(
-                    n_levels, 1, latent_distribution="normal", mean_bias=True, var_bias=True
+                    n_levels, 1, latent_distribution=latent_distribution, mean_bias=True, var_bias=True
                 )
 
                 # Register encoder in class
@@ -218,7 +218,7 @@ def __init__(
         if continuous_names is not None:
             for con_name, dim in zip(continuous_names, range(n_latent + n_cat, n_latent + n_cat + n_con)):
                 name = f"encoder_{con_name}"
-                model = LinearEncoder(1, 1, latent_distribution="normal")
+                model = LinearEncoder(1, 1, latent_distribution=latent_distribution)
 
                 # Register encoder in class
                 setattr(self, name, model)
@@ -338,9 +338,6 @@ def inference(self, x, batch_index, continuous_covariates, categorical_covariate
         var_z = torch.cat([var_counts, *var_ca, *var_cc], dim=1)
         z = torch.cat([latent_counts, *latent_ca, *latent_cc], dim=1)
 
-        if self.latent_distribution == "ln":
-            z = F.softmax(z, dim=-1)
-
         qz = Normal(loc=mean_z, scale=torch.sqrt(var_z))
 
         if n_samples > 1:
diff --git a/sccoral/nn/_components.py b/sccoral/nn/_components.py
index 75a08f5..92b1d0a 100644
--- a/sccoral/nn/_components.py
+++ b/sccoral/nn/_components.py
@@ -63,7 +63,7 @@ def __init__(
         self.var_eps = var_eps
 
         if latent_distribution == "ln":
-            self.z_transformation = nn.Softmax(dim=-1)
+            self.z_transformation = nn.Sigmoid()
         else:
             # Identity function
             self.z_transformation = lambda x: x
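
Illustrative note (a minimal standalone sketch, not part of the patch; the
tensor shape and the perturbation below are made up): the behavioural change
this commit makes for latent_distribution="ln" is that the old joint softmax
normalized across the whole concatenated latent vector, so shifting one
covariate-informed logit changed every other dimension, while the
per-dimension sigmoid now applied inside each encoder leaves the remaining
dimensions untouched.

    import torch
    import torch.nn as nn

    z = torch.randn(4, 6)  # 4 cells, 6 latent dimensions (illustrative)

    # Old behaviour: one softmax over the whole latent vector couples all
    # dimensions -- the transformed values must sum to 1 per cell.
    softmax = nn.Softmax(dim=-1)
    z_old = softmax(z)

    # New behaviour: sigmoid squashes each dimension independently.
    sigmoid = nn.Sigmoid()
    z_new = sigmoid(z)

    # Perturb a single logit and check the effect on the other dimensions.
    z_pert = z.clone()
    z_pert[:, 0] += 5.0
    print((softmax(z_pert)[:, 1:] - z_old[:, 1:]).abs().max())  # nonzero: coupled
    print((sigmoid(z_pert)[:, 1:] - z_new[:, 1:]).abs().max())  # zero: independent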