diff --git a/tests/test_model.py b/tests/test_model.py index 74d20df..fb181c9 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -78,12 +78,12 @@ def test_pretraining_early_stopping(adata, pretraining_early_stopping, requires_ assert model.module.z_encoder.encoder.fc_layers[0][0].weight.requires_grad == requires_grad -@pytest.fixture(scope="module") -def basic_train(adata): +@pytest.fixture(scope="module", params=["normal", "ln"]) +def basic_train(adata, request): SCCORAL.setup_anndata( adata, categorical_covariates="categorical_covariate", continuous_covariates="continuous_covariate" ) - model = SCCORAL(adata, n_latent=5) + model = SCCORAL(adata, n_latent=5, latent_distribution=request.param) model.train(max_epochs=20, accelerator="cpu") return model