Commit

Fixed code. Outputs only NaN
canergen committed Sep 23, 2023
1 parent 0fa645c commit 2ef0326
Showing 2 changed files with 10 additions and 7 deletions.
6 changes: 5 additions & 1 deletion src/embedding_scvi/_components.py
@@ -77,8 +77,8 @@ def __init__(
  self,
  n_in: int,
  n_out: int,
- cat_dim: int | None = None,
  bias: bool = True,
+ cat_dim: int | None = None,
  conditional: bool = False,
  norm: Literal["batch", "layer"] | None = None,
  norm_kwargs: dict | None = None,
@@ -255,6 +255,7 @@ def __init__(
  n_hidden: int,
  n_layers: int,
  bias: bool = True,
+ cat_dim: int | None = None,
  norm: str | None = None,
  norm_kwargs: dict | None = None,
  activation: str | None = None,
@@ -277,6 +278,7 @@ def __init__(
  n_in=n_in,
  n_out=n_out,
  bias=bias,
+ cat_dim=cat_dim,
  norm=norm,
  norm_kwargs=norm_kwargs,
  activation=activation,
@@ -349,6 +351,7 @@ def __init__(
  n_hidden: int,
  n_layers: int,
  bias: bool = True,
+ cat_dim: int | None = None,
  norm: str | None = None,
  norm_kwargs: dict | None = None,
  activation: str | None = None,
@@ -366,6 +369,7 @@ def __init__(
  n_hidden=n_hidden,
  n_layers=n_layers,
  bias=bias,
+ cat_dim=cat_dim,
  norm=norm,
  norm_kwargs=norm_kwargs,
  activation=activation,
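Every hunk in _components.py follows the same pattern: a cat_dim argument is added to a constructor signature right after bias and then forwarded unchanged to the wrapped block (bias=bias, cat_dim=cat_dim). A minimal sketch of that plumbing, assuming the repository's classes behave roughly like the hypothetical MLPBlock/MLP below (names and the concatenation strategy are illustrative, not the actual implementation):

import torch
from torch import nn


class MLPBlock(nn.Module):
    # Hypothetical single block: a linear layer whose input can be widened
    # by cat_dim extra features (e.g. a one-hot categorical covariate).
    def __init__(self, n_in: int, n_out: int, bias: bool = True, cat_dim: int | None = None):
        super().__init__()
        extra = cat_dim if cat_dim is not None else 0
        self.linear = nn.Linear(n_in + extra, n_out, bias=bias)

    def forward(self, x: torch.Tensor, cat: torch.Tensor | None = None) -> torch.Tensor:
        if cat is not None:
            x = torch.cat([x, cat], dim=-1)
        return self.linear(x)


class MLP(nn.Module):
    # Hypothetical stack of blocks; cat_dim is simply forwarded to every block,
    # mirroring the bias=bias / cat_dim=cat_dim lines in the diff above.
    def __init__(self, n_in: int, n_out: int, n_hidden: int, n_layers: int,
                 bias: bool = True, cat_dim: int | None = None):
        super().__init__()
        dims = [n_in] + [n_hidden] * (n_layers - 1) + [n_out]
        self.blocks = nn.ModuleList(
            MLPBlock(d_in, d_out, bias=bias, cat_dim=cat_dim)
            for d_in, d_out in zip(dims[:-1], dims[1:])
        )

    def forward(self, x: torch.Tensor, cat: torch.Tensor | None = None) -> torch.Tensor:
        for block in self.blocks:
            x = block(x, cat)
        return x

Because cat_dim defaults to None, callers that never pass categorical covariates are unaffected by the new parameter.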
11 changes: 5 additions & 6 deletions src/embedding_scvi/_module.py
@@ -79,7 +79,6 @@ def __init__(
  embedding_dim=self.n_latent,
  )

-
  encoder_dist_params = likelihood_to_dist_params("normal")
  _encoder_kwargs = {
  "n_hidden": 256,
@@ -89,7 +88,7 @@ def __init__(
  "activation": "gelu",
  "dropout_rate": 0.1,
  "residual": True,
- "cat_dim": self.categorical_covariates.num_embeddings,
+ "cat_dim": self.covariates_encoder.num_embeddings,
  }
  _encoder_kwargs.update(self.encoder_kwargs)
  self.encoder = MultiOutputMLP(
@@ -109,7 +108,7 @@ def __init__(
  "activation": "gelu",
  "dropout_rate": None,
  "residual": True,
- "n_out_params": self.categorical_covariates.num_embeddings,
+ "cat_dim": self.covariates_encoder.num_embeddings
  }

  _decoder_kwargs.update(self.decoder_kwargs)
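Both kwargs dicts now take the category count from self.covariates_encoder.num_embeddings. Assuming covariates_encoder is a torch nn.Embedding (its definition is not shown in this diff), num_embeddings is simply the number of categories it can embed, so cat_dim matches the width of a one-hot encoding of the covariate:

import torch
from torch import nn

# Assumed setup: covariates_encoder embeds a categorical covariate with 12 levels.
covariates_encoder = nn.Embedding(num_embeddings=12, embedding_dim=16)

cat_dim = covariates_encoder.num_embeddings  # 12, the value fed into the MLP kwargs
one_hot = torch.nn.functional.one_hot(torch.tensor([3]), num_classes=cat_dim)
print(cat_dim, one_hot.shape)  # 12 torch.Size([1, 12])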
@@ -140,7 +139,7 @@ def _get_inference_input(self, tensors: dict[str, torch.Tensor]) -> dict:
  covariates = tensors.get(REGISTRY_KEYS.CAT_COVS_KEY, None)
  return {
  REGISTRY_KEYS.X_KEY: x,
- REGISTRY_KEYS.LABELS_KEY: y,
+ "y": y,
  REGISTRY_KEYS.CAT_COVS_KEY: covariates,
  }

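In scvi-tools-style modules, the dict returned by _get_inference_input is typically unpacked as keyword arguments into inference(), so each key must match a parameter name of inference(); REGISTRY_KEYS.LABELS_KEY resolves to the string "labels" in scvi-tools, so switching the key to "y" only works if inference() actually names that argument y. A toy illustration of the convention (function bodies are hypothetical):

import torch

def _get_inference_input(tensors: dict[str, torch.Tensor]) -> dict:
    # The keys chosen here must line up with inference()'s parameter names.
    return {"X": tensors["X"], "y": tensors["labels"]}

def inference(X: torch.Tensor, y: torch.Tensor) -> dict:
    return {"x_mean": X.mean(), "n_labels": int(y.max()) + 1}

tensors = {"X": torch.rand(5, 3), "labels": torch.tensor([0, 1, 0, 2, 1])}
print(inference(**_get_inference_input(tensors)))  # works because keys == parameter names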
@@ -152,8 +151,8 @@ def inference(
  extra_categorical_covs: torch.Tensor | None = None,
  subset_categorical_covs: int | list[int] | None = None,
  ):
- X = torch.log1p(X)
  library_size = torch.log(X.sum(dim=1, keepdim=True))
+ X = torch.log1p(X)

  posterior_loc, posterior_scale = self.encoder(X)
  posterior = dist.Normal(posterior_loc, posterior_scale + 1e-9)
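Swapping these two statements changes what library_size measures: it is now the log of each cell's raw count total rather than the log of a sum of log1p-transformed values. A standalone sketch of the corrected order, assuming X holds non-negative counts of shape (cells, genes):

import torch

X = torch.tensor([[0.0, 2.0, 5.0], [1.0, 0.0, 3.0]])  # assumed raw counts

# Library size is computed on the raw counts *before* any transformation.
# A cell with zero total counts would still yield log(0) = -inf here,
# so filtering of empty cells upstream is assumed.
library_size = torch.log(X.sum(dim=1, keepdim=True))

# Only afterwards is X log1p-transformed for the encoder.
X = torch.log1p(X)
print(library_size.squeeze(-1))  # tensor([1.9459, 1.3863])
print(X)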
@@ -165,7 +164,7 @@ def inference(
  torch.exp(self.u_prior_scales))
  prior = dist.MixtureSameFamily(cats, normal_dists)
  elif self.prior=='mog_celltype':
- label_bias = 10.0 * torch.nn.functional.one_hot(y, self.n_labels) if self.n_labels >= 2 else 0.0
+ label_bias = 10.0 * torch.nn.functional.one_hot(labels, self.n_labels) if self.n_labels >= 2 else 0.0
  cats = dist.Categorical(logits=self.u_prior_logits + label_bias)
  normal_dists = dist.Normal(
  self.u_prior_means,
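The 'mog_celltype' branch builds a mixture-of-Gaussians prior whose mixture logits get a large one-hot boost on the component matching each cell's label, so that component dominates the mixture weights. A self-contained sketch of that construction with made-up shapes; the Independent wrapper, which treats the latent dimensions as one event, is an assumption about how the repository shapes its prior parameters:

import torch
import torch.distributions as dist

n_cells, n_labels, n_latent = 4, 3, 2

# Stand-ins for the module's learnable prior parameters.
u_prior_logits = torch.zeros(n_labels)
u_prior_means = torch.randn(n_labels, n_latent)
u_prior_scales = torch.zeros(n_labels, n_latent)  # interpreted as log-scales

labels = torch.tensor([0, 2, 1, 2])

# +10 on the labelled component's logit pushes each cell's mixture weights toward its own label.
label_bias = 10.0 * torch.nn.functional.one_hot(labels, n_labels) if n_labels >= 2 else 0.0
cats = dist.Categorical(logits=u_prior_logits + label_bias)  # one mixture weighting per cell

components = dist.Independent(dist.Normal(u_prior_means, torch.exp(u_prior_scales)), 1)
prior = dist.MixtureSameFamily(cats, components)

z = torch.randn(n_cells, n_latent)
print(prior.log_prob(z).shape)  # torch.Size([4]): one prior log-density per cell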
