
Commit a194b23

Merge pull request #21 from lucas-diedrich/experimental
Independent logit for covariate-informed and gene-expression-informed dimensions
lucas-diedrich authored May 22, 2024
2 parents e679b1b + e488739 commit a194b23
Showing 2 changed files with 4 additions and 7 deletions.
9 changes: 3 additions & 6 deletions sccoral/module/_module.py
```diff
@@ -171,7 +171,7 @@ def __init__(
             n_layers=n_layers,
             n_hidden=n_hidden,
             dropout_rate=dropout_rate,
-            distribution="normal",
+            distribution=latent_distribution,
             use_batch_norm=self.use_batch_norm_encoder,
             use_layer_norm=self.use_layer_norm_encoder,
             return_dist=False,
@@ -202,7 +202,7 @@ def __init__(
                 name = f"encoder_{cat_name}"
 
                 model = LinearEncoder(
-                    n_levels, 1, latent_distribution="normal", mean_bias=True, var_bias=True
+                    n_levels, 1, latent_distribution=latent_distribution, mean_bias=True, var_bias=True
                 )
 
                 # Register encoder in class
@@ -218,7 +218,7 @@ def __init__(
         if continuous_names is not None:
             for con_name, dim in zip(continuous_names, range(n_latent + n_cat, n_latent + n_cat + n_con)):
                 name = f"encoder_{con_name}"
-                model = LinearEncoder(1, 1, latent_distribution="normal")
+                model = LinearEncoder(1, 1, latent_distribution=latent_distribution)
 
                 # Register encoder in class
                 setattr(self, name, model)
@@ -338,9 +338,6 @@ def inference(self, x, batch_index, continuous_covariates, categorical_covariate
         var_z = torch.cat([var_counts, *var_ca, *var_cc], dim=1)
         z = torch.cat([latent_counts, *latent_ca, *latent_cc], dim=1)
 
-        if self.latent_distribution == "ln":
-            z = F.softmax(z, dim=-1)
-
         qz = Normal(loc=mean_z, scale=torch.sqrt(var_z))
 
         if n_samples > 1:
```
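In effect, this hunk removes the joint softmax over the concatenated latent vector: previously, with `latent_distribution="ln"`, a single softmax was applied across all of `z`, coupling the gene-expression-informed and covariate-informed dimensions into one shared simplex. The logistic transformation now happens per encoder instead (see `_components.py` below). A minimal standalone sketch of the distinction — tensor shapes are illustrative only, not taken from the module:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: a batch of 4 cells with 10 gene-expression-informed
# latent dimensions and 2 covariate-informed dimensions.
latent_counts = torch.randn(4, 10)
latent_cov = torch.randn(4, 2)
z = torch.cat([latent_counts, latent_cov], dim=1)

# Removed behavior: one softmax over the full concatenated vector, so all
# 12 dimensions compete for probability mass in a single simplex.
z_joint = F.softmax(z, dim=-1)
print(z_joint.sum(dim=-1))  # all ones: every dimension is coupled to every other

# New behavior: each dimension is squashed independently (elementwise
# sigmoid in LinearEncoder, next file), so a covariate-informed dimension
# no longer trades off against the gene-expression-informed ones.
z_independent = torch.sigmoid(z)
print(z_independent.sum(dim=-1))  # unconstrained: dimensions stay independent
```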
2 changes: 1 addition & 1 deletion sccoral/nn/_components.py
```diff
@@ -63,7 +63,7 @@ def __init__(
         self.var_eps = var_eps
 
         if latent_distribution == "ln":
-            self.z_transformation = nn.Softmax(dim=-1)
+            self.z_transformation = nn.Sigmoid()
         else:
             # Identity function
             self.z_transformation = lambda x: x
```
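This is the matching change in `LinearEncoder`: for the logistic-normal (`"ln"`) case, each latent dimension now gets its own independent logit via an elementwise sigmoid, rather than a softmax that normalizes across dimensions. A quick sketch of the two transformations (printed values are approximate):

```python
import torch
import torch.nn as nn

z = torch.tensor([[2.0, -1.0, 0.5]])

softmax = nn.Softmax(dim=-1)  # joint: outputs sum to 1 across dimensions
sigmoid = nn.Sigmoid()        # independent: each output lies in (0, 1) on its own

print(softmax(z))  # ~[[0.7856, 0.0391, 0.1753]]  (sums to 1)
print(sigmoid(z))  # ~[[0.8808, 0.2689, 0.6225]]  (no sum constraint)
```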
